Files
traefikoidc/token_consolidated_test.go
T
lukaszraczylo e64fc7f730 Add redis support for distributed caching (#83)
* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-11-30 02:18:46 +00:00

915 lines
26 KiB
Go

package traefikoidc
import (
"bytes"
"compress/gzip"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"text/template"
"time"
"golang.org/x/time/rate"
)
// ============================================================================
// Test Constants
// ============================================================================
// Well-formed token fixtures shared by several test files.
// ValidAccessToken and ValidIDToken hold the identical RS256-labelled JWT string;
// MinimalValidJWT uses an {"alg":"none"} header and an empty signature segment.
var (
	ValidAccessToken  = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU"
	ValidIDToken      = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU"
	ValidRefreshToken = "refresh_token_abc123"
	MinimalValidJWT   = "eyJhbGciOiJub25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0."
)

// Malformed token fixtures. Note that InvalidTokenThreeDots actually contains
// two dots (three segments, the middle one empty) despite its name.
var (
	InvalidTokenOneDot    = "invalid.token"
	InvalidTokenNoDots    = "invalidtoken"
	InvalidTokenThreeDots = "invalid..token"
)
// ============================================================================
// Token Type Tests
// ============================================================================
// TestTokenTypes checks that access and ID tokens stay distinguishable: when
// rendered through text/template, and when stored on a session and read back.
func TestTokenTypes(t *testing.T) {
	t.Run("TokenTypeDistinction", func(t *testing.T) {
		// Every template case below is rendered against this single value.
		data := struct {
			Claims       map[string]interface{}
			AccessToken  string
			IDToken      string
			RefreshToken string
		}{
			Claims: map[string]interface{}{
				"sub":   "test-subject",
				"email": "user@example.com",
			},
			AccessToken:  "test-access-token-abc123",
			IDToken:      "test-id-token-xyz789",
			RefreshToken: "test-refresh-token",
		}

		cases := []struct {
			name, text, want string
		}{
			{"Access Token Only", "Bearer {{.AccessToken}}", "Bearer test-access-token-abc123"},
			{"ID Token Only", "ID: {{.IDToken}}", "ID: test-id-token-xyz789"},
			{"Both Tokens", "Access: {{.AccessToken}} ID: {{.IDToken}}", "Access: test-access-token-abc123 ID: test-id-token-xyz789"},
			{"Both Tokens in Authorization Format", "Bearer {{.AccessToken}} and Bearer {{.IDToken}}", "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789"},
		}

		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				tmpl, err := template.New("test").Parse(c.text)
				if err != nil {
					t.Fatalf("Failed to parse template: %v", err)
				}
				var out bytes.Buffer
				if err := tmpl.Execute(&out, data); err != nil {
					t.Fatalf("Failed to execute template: %v", err)
				}
				if got := out.String(); got != c.want {
					t.Errorf("Expected template output %q, got %q", c.want, got)
				}
			})
		}
	})

	t.Run("TokenTypeIntegration", func(t *testing.T) {
		suite := NewTestSuite(t)
		suite.Setup()

		idToken, err := createTestJWT(suite.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
			"iss":        "https://test-issuer.com",
			"aud":        "test-client-id",
			"exp":        float64(3000000000),
			"sub":        "id-token-subject",
			"email":      "id@example.com",
			"nonce":      "test-nonce",
			"token_type": "id",
		})
		if err != nil {
			t.Fatalf("Failed to create ID token: %v", err)
		}

		accessToken, err := createTestJWT(suite.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
			"iss":        "https://test-issuer.com",
			"aud":        "test-client-id",
			"exp":        float64(3000000000),
			"sub":        "access-token-subject",
			"email":      "access@example.com",
			"scope":      "openid email profile",
			"token_type": "access",
		})
		if err != nil {
			t.Fatalf("Failed to create access token: %v", err)
		}

		// Put both tokens on one session, then confirm each comes back unchanged
		// and in its own slot.
		session, err := suite.sessionManager.GetSession(httptest.NewRequest("GET", "http://example.com", nil))
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		defer session.ReturnToPool()

		session.SetIDToken(idToken)
		session.SetAccessToken(accessToken)

		if got := session.GetIDToken(); got != idToken {
			t.Errorf("ID token mismatch: expected %q, got %q", idToken, got)
		}
		if got := session.GetAccessToken(); got != accessToken {
			t.Errorf("Access token mismatch: expected %q, got %q", accessToken, got)
		}
	})
}
// ============================================================================
// Token Corruption Tests
// ============================================================================
// TestTokenCorruption exercises access-token storage on a session under
// repeated reads and after deliberately corrupting session internals, and
// checks how base64url decoding reacts to malformed input.
func TestTokenCorruption(t *testing.T) {
	t.Run("TokenCorruptionScenario", func(t *testing.T) {
		logger := NewLogger("debug")
		sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
		if err != nil {
			t.Fatalf("Failed to create session manager: %v", err)
		}
		testTokens := NewTestTokens()
		validJWT := testTokens.CreateLargeValidJWT(100)
		tests := []struct {
			name               string
			tokenSize          int
			iterations         int
			expectConsistent   bool
			corruptionScenario func(*SessionData)
		}{
			{
				name:             "Small token - multiple retrievals",
				tokenSize:        len(validJWT),
				iterations:       10,
				expectConsistent: true,
			},
			{
				name:             "Large chunked token - multiple retrievals",
				tokenSize:        5000,
				iterations:       10,
				expectConsistent: true,
			},
			{
				name:             "Compression corruption simulation",
				tokenSize:        2000,
				iterations:       5,
				expectConsistent: false,
				corruptionScenario: func(session *SessionData) {
					if session.accessSession != nil {
						session.accessSession.Values["token"] = "corrupted_base64_!@#$"
						session.accessSession.Values["compressed"] = true
					}
				},
			},
			{
				name:             "Chunk reassembly corruption simulation",
				tokenSize:        25000,
				iterations:       5,
				expectConsistent: false,
				corruptionScenario: func(session *SessionData) {
					if len(session.accessTokenChunks) > 0 {
						if chunk, exists := session.accessTokenChunks[0]; exists {
							chunk.Values["token_chunk"] = "invalid_base64_!@#$%"
						}
					}
				},
			},
		}
		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				req := httptest.NewRequest("GET", "http://example.com/foo", nil)
				session, err := sm.GetSession(req)
				if err != nil {
					t.Fatalf("Failed to get session: %v", err)
				}
				defer session.ReturnToPool()
				token := createTokenOfSize(validJWT, tt.tokenSize)
				session.SetAccessToken(token)
				// Read the token repeatedly: reads must not disturb the stored
				// value. Each mismatch is reported exactly once, here, with its
				// iteration index (previously a second pass re-reported it).
				for i := 0; i < tt.iterations; i++ {
					retrieved := session.GetAccessToken()
					if tt.expectConsistent && retrieved != token {
						t.Errorf("Iteration %d: Token changed unexpectedly", i)
					}
				}
				// After tampering with session internals the original token must
				// no longer come back.
				if tt.corruptionScenario != nil {
					tt.corruptionScenario(session)
					if retrieved := session.GetAccessToken(); retrieved == token {
						t.Error("Expected corrupted token to be different")
					}
				}
			})
		}
	})
	t.Run("Base64CorruptionHandling", func(t *testing.T) {
		tests := []struct {
			name        string
			input       string
			expectError bool
		}{
			{"Valid base64", "eyJhbGciOiJSUzI1NiJ9", false},
			{"Invalid characters", "eyJ!@#$%^&*()", true},
			{"Missing padding", "eyJhbGc", false}, // base64url doesn't require padding
			{"Empty string", "", false},
			{"Spaces in base64", "eyJ hbG ciOi JSU zI1 NiJ9", true},
		}
		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				// TrimSpace only strips the ends, so embedded spaces still fail.
				_, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(tt.input))
				hasError := err != nil
				if hasError != tt.expectError {
					t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
				}
			})
		}
	})
}
// ============================================================================
// Token Resilience Tests
// ============================================================================
func TestTokenResilience(t *testing.T) {
t.Run("ConcurrentTokenAccess", func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
testToken := "test-token-" + generateRandomString(100)
session.SetAccessToken(testToken)
var wg sync.WaitGroup
errors := make(chan error, 100)
successCount := int32(0)
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
retrieved := session.GetAccessToken()
if retrieved == testToken {
atomic.AddInt32(&successCount, 1)
} else {
errors <- fmt.Errorf("token mismatch: expected %q, got %q", testToken, retrieved)
}
}()
}
wg.Wait()
close(errors)
for err := range errors {
t.Error(err)
}
if successCount != 100 {
t.Errorf("Expected 100 successful retrievals, got %d", successCount)
}
})
t.Run("TokenSizeHandling", func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
sizes := []int{
100, // Small token
1000, // Medium token
4000, // Just under chunk threshold
5000, // Just over chunk threshold
10000, // Large token requiring chunking
20000, // Very large token (but within 25 chunk limit)
}
for _, size := range sizes {
t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Create a valid JWT token of the desired size
token := createTokenOfSize(ValidAccessToken, size)
session.SetAccessToken(token)
retrieved := session.GetAccessToken()
// For very large tokens that exceed chunk limits, retrieval will fail
if size > 15000 && retrieved == "" {
// Expected failure for very large tokens
t.Logf("Token size %d exceeds chunk limits (expected)", size)
} else if retrieved != token {
t.Errorf("Token mismatch for size %d", size)
}
})
}
})
t.Run("RateLimitedTokenRefresh", func(t *testing.T) {
limiter := rate.NewLimiter(rate.Limit(10), 1) // 10 requests per second
var wg sync.WaitGroup
successCount := int32(0)
deniedCount := int32(0)
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if limiter.Allow() {
atomic.AddInt32(&successCount, 1)
} else {
atomic.AddInt32(&deniedCount, 1)
}
}()
time.Sleep(10 * time.Millisecond) // Spread requests over 500ms
}
wg.Wait()
t.Logf("Allowed: %d, Denied: %d", successCount, deniedCount)
if successCount == 0 {
t.Error("No requests were allowed")
}
if successCount == 50 {
t.Error("All requests were allowed, rate limiting not working")
}
})
}
// ============================================================================
// Token Validation Tests
// ============================================================================
// TestTokenValidation checks two structural rules in isolation: a JWT must have
// exactly three dot-separated segments, and an exp instant counts as valid only
// when it lies strictly after the reference time.
func TestTokenValidation(t *testing.T) {
	t.Run("JWTStructureValidation", func(t *testing.T) {
		cases := []struct {
			name      string
			token     string
			wantValid bool
		}{
			{"Valid JWT structure", "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature", true},
			{"Missing signature", "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0", false},
			// An empty middle segment still makes three segments.
			{"Missing payload", "eyJhbGciOiJSUzI1NiJ9..signature", true},
			{"Only header", "eyJhbGciOiJSUzI1NiJ9", false},
			{"Too many parts", "header.payload.signature.extra", false},
			{"Empty token", "", false},
		}
		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				// Exactly two separators <=> exactly three segments.
				got := strings.Count(c.token, ".") == 2
				if got != c.wantValid {
					t.Errorf("Expected valid=%v, got %v", c.wantValid, got)
				}
			})
		}
	})

	t.Run("TokenExpiryValidation", func(t *testing.T) {
		now := time.Now()
		cases := []struct {
			name      string
			exp       time.Time
			wantValid bool
		}{
			{"Future expiry", now.Add(time.Hour), true},
			{"Just expired", now.Add(-time.Second), false},
			{"Long expired", now.Add(-24 * time.Hour), false},
			{"Far future", now.Add(365 * 24 * time.Hour), true},
		}
		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				got := c.exp.Sub(now) > 0
				if got != c.wantValid {
					t.Errorf("Expected valid=%v, got %v", c.wantValid, got)
				}
			})
		}
	})
}
// ============================================================================
// Token Chunking Tests
// ============================================================================
// TestTokenChunking checks the ceiling arithmetic for splitting a token into
// fixed 4000-character pieces, and that concatenating the pieces reproduces
// the original token.
func TestTokenChunking(t *testing.T) {
	const chunkSize = 4000

	t.Run("ChunkSplitting", func(t *testing.T) {
		cases := []struct {
			name string
			size int
			want int
		}{
			{"Small token", 100, 1},
			{"Just under chunk size", 3999, 1},
			{"Exactly chunk size", 4000, 1},
			{"Just over chunk size", 4100, 2},
			{"Multiple chunks", 10000, 3},
			{"Large token", 50000, 13},
		}
		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				n := len(generateRandomString(c.size))
				// Whole chunks, plus one for any remainder.
				chunks := n / chunkSize
				if n%chunkSize != 0 {
					chunks++
				}
				if chunks != c.want {
					t.Errorf("Expected %d chunks, got %d", c.want, chunks)
				}
			})
		}
	})

	t.Run("ChunkReassembly", func(t *testing.T) {
		originalToken := generateRandomString(10000)

		// Consume the token front-to-back in chunkSize slices.
		var chunks []string
		for rest := originalToken; rest != ""; {
			take := chunkSize
			if take > len(rest) {
				take = len(rest)
			}
			chunks = append(chunks, rest[:take])
			rest = rest[take:]
		}

		if strings.Join(chunks, "") != originalToken {
			t.Error("Token reassembly failed")
		}
	})
}
// ============================================================================
// Token Compression Tests
// ============================================================================
// TestTokenCompression checks that highly repetitive input gzip-compresses by
// at least 10x, and that gzip round-trips token-like strings unchanged.
func TestTokenCompression(t *testing.T) {
	t.Run("CompressionEfficiency", func(t *testing.T) {
		// Create a token with repetitive content (compresses well)
		repetitiveToken := strings.Repeat("AAAA", 1000)
		var compressed bytes.Buffer
		gz := gzip.NewWriter(&compressed)
		if _, err := gz.Write([]byte(repetitiveToken)); err != nil {
			t.Fatalf("Compression failed: %v", err)
		}
		// Close flushes the last deflate block and writes the gzip trailer. If
		// its error went unchecked, compressed.Len() could be measuring a
		// truncated stream.
		if err := gz.Close(); err != nil {
			t.Fatalf("Failed to finalize compression: %v", err)
		}
		compressionRatio := float64(len(repetitiveToken)) / float64(compressed.Len())
		t.Logf("Compression ratio: %.2fx (original: %d, compressed: %d)",
			compressionRatio, len(repetitiveToken), compressed.Len())
		if compressionRatio < 10 {
			t.Error("Expected better compression for repetitive data")
		}
	})
	t.Run("CompressionDecompression", func(t *testing.T) {
		tokens := []string{
			generateRandomString(100),
			generateRandomString(1000),
			generateRandomString(10000),
			strings.Repeat("A", 5000), // Highly compressible
		}
		for i, token := range tokens {
			t.Run(fmt.Sprintf("Token_%d", i), func(t *testing.T) {
				// Compress
				var compressed bytes.Buffer
				gz := gzip.NewWriter(&compressed)
				if _, err := gz.Write([]byte(token)); err != nil {
					t.Fatalf("Compression failed: %v", err)
				}
				// Without a successful Close the stream lacks its trailer and
				// decompression would fail for the wrong reason.
				if err := gz.Close(); err != nil {
					t.Fatalf("Failed to finalize compression: %v", err)
				}
				// Decompress
				reader, err := gzip.NewReader(&compressed)
				if err != nil {
					t.Fatalf("Failed to create decompressor: %v", err)
				}
				defer reader.Close()
				var decompressed bytes.Buffer
				// Reading to EOF also makes gzip verify the trailer checksum.
				if _, err := decompressed.ReadFrom(reader); err != nil {
					t.Fatalf("Decompression failed: %v", err)
				}
				if decompressed.String() != token {
					t.Error("Token changed after compression/decompression")
				}
			})
		}
	})
}
// ============================================================================
// Ajax Token Expiry Tests
// ============================================================================
// TestAjaxTokenExpiry covers how an expired token should be answered:
// XMLHttpRequest callers get 401 with a JSON error body, while regular
// navigations get a 302 redirect to the login page. It also checks that a
// client retrying after 401 eventually succeeds.
func TestAjaxTokenExpiry(t *testing.T) {
	t.Run("AjaxExpiryDetection", func(t *testing.T) {
		tests := []struct {
			name           string
			isAjax         bool
			tokenExpired   bool
			expectedStatus int
		}{
			{"Regular request, valid token", false, false, http.StatusOK},
			{"Regular request, expired token", false, true, http.StatusFound},
			{"Ajax request, valid token", true, false, http.StatusOK},
			{"Ajax request, expired token", true, true, http.StatusUnauthorized},
		}
		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				req := httptest.NewRequest("GET", "http://example.com", nil)
				if tt.isAjax {
					req.Header.Set("X-Requested-With", "XMLHttpRequest")
				}
				w := httptest.NewRecorder()
				// Simulate token validation
				if tt.tokenExpired {
					if tt.isAjax {
						w.WriteHeader(http.StatusUnauthorized)
						w.Write([]byte(`{"error": "token_expired", "message": "Your session has expired"}`))
					} else {
						// Headers must be set before WriteHeader: changes made
						// afterwards are not sent (the recorder snapshots them).
						w.Header().Set("Location", "/auth/login")
						w.WriteHeader(http.StatusFound)
					}
				} else {
					w.WriteHeader(http.StatusOK)
					w.Write([]byte("Success"))
				}
				if w.Code != tt.expectedStatus {
					t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
				}
				if tt.isAjax && tt.tokenExpired {
					body := w.Body.String()
					if !strings.Contains(body, "token_expired") {
						t.Error("Expected token_expired error in response")
					}
				}
				if !tt.isAjax && tt.tokenExpired {
					// Result() reflects the headers as of WriteHeader, i.e. what
					// a real client would actually receive.
					if loc := w.Result().Header.Get("Location"); loc != "/auth/login" {
						t.Errorf("Expected redirect Location %q, got %q", "/auth/login", loc)
					}
				}
			})
		}
	})
	t.Run("AjaxRetryMechanism", func(t *testing.T) {
		const maxRetries = 3
		// Incremented on the server goroutine and read on the test goroutine.
		var attemptCount int32
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if atomic.AddInt32(&attemptCount, 1) < maxRetries {
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte(`{"error": "token_expired"}`))
			} else {
				w.WriteHeader(http.StatusOK)
				w.Write([]byte(`{"success": true}`))
			}
		})
		server := httptest.NewServer(handler)
		defer server.Close()
		// Simulate client with retry logic
		client := &http.Client{Timeout: 5 * time.Second}
		lastStatus := 0
		for i := 0; i < maxRetries; i++ {
			req, err := http.NewRequest("GET", server.URL, nil)
			if err != nil {
				t.Fatalf("Failed to build request: %v", err)
			}
			req.Header.Set("X-Requested-With", "XMLHttpRequest")
			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("Request failed: %v", err)
			}
			// Only the status is needed; close every body exactly once.
			lastStatus = resp.StatusCode
			resp.Body.Close()
			if lastStatus == http.StatusOK {
				break
			}
		}
		if lastStatus != http.StatusOK {
			t.Errorf("Expected successful retry, got status %d", lastStatus)
		}
		if got := atomic.LoadInt32(&attemptCount); got != maxRetries {
			t.Errorf("Expected %d attempts, got %d", maxRetries, got)
		}
	})
}
// ============================================================================
// Test Token Creation Helper Tests
// ============================================================================
// TestTestTokens sanity-checks the TestTokens fixture helpers: segment count,
// header algorithm, a minimum length for generated large tokens, and that the
// expired fixture's exp claim lies in the past.
func TestTestTokens(t *testing.T) {
	t.Run("CreateValidJWT", func(t *testing.T) {
		tokens := NewTestTokens()
		jwt := tokens.CreateValidJWT()
		parts := strings.Split(jwt, ".")
		// Fatal rather than Error: the checks below index into parts.
		if len(parts) != 3 {
			t.Fatalf("Expected 3 JWT parts, got %d", len(parts))
		}
		// Decode and verify header
		headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
		if err != nil {
			t.Fatalf("Failed to decode header: %v", err)
		}
		var header map[string]interface{}
		if err := json.Unmarshal(headerJSON, &header); err != nil {
			t.Fatalf("Failed to parse header: %v", err)
		}
		if header["alg"] != "RS256" {
			t.Errorf("Expected RS256 algorithm, got %v", header["alg"])
		}
	})
	t.Run("CreateLargeValidJWT", func(t *testing.T) {
		tokens := NewTestTokens()
		sizes := []int{10, 100, 1000}
		for _, size := range sizes {
			t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
				jwt := tokens.CreateLargeValidJWT(size)
				// Verify it's a valid JWT structure
				parts := strings.Split(jwt, ".")
				if len(parts) != 3 {
					t.Fatalf("Expected 3 JWT parts, got %d", len(parts))
				}
				// The JWT is larger than the claim because the payload also
				// carries fixed claims and is base64url-encoded (~4/3 growth),
				// plus the header and signature segments.
				minExpectedSize := size + 200 // claim size + headers/structure overhead
				if len(jwt) < minExpectedSize {
					t.Errorf("JWT seems too small for requested claim size: got %d, expected at least %d", len(jwt), minExpectedSize)
				}
			})
		}
	})
	t.Run("CreateExpiredJWT", func(t *testing.T) {
		tokens := NewTestTokens()
		jwt := tokens.CreateExpiredJWT()
		parts := strings.Split(jwt, ".")
		// Must stop here: parts[1] below would panic on a token without dots.
		if len(parts) != 3 {
			t.Fatalf("Expected 3 JWT parts, got %d", len(parts))
		}
		// Decode payload to verify expiration
		payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
		if err != nil {
			t.Fatalf("Failed to decode payload: %v", err)
		}
		var payload map[string]interface{}
		if err := json.Unmarshal(payloadJSON, &payload); err != nil {
			t.Fatalf("Failed to parse payload: %v", err)
		}
		// encoding/json decodes JSON numbers into float64 for interface{} targets.
		exp, ok := payload["exp"].(float64)
		if !ok {
			t.Fatal("Expected exp claim in payload")
		}
		if exp >= float64(time.Now().Unix()) {
			t.Error("Token should be expired")
		}
	})
}
// ============================================================================
// Helper Functions
// ============================================================================
// MockJWTVerifier is a test stub whose verification outcome is fixed up front
// by the valid field; the token contents are never inspected.
type MockJWTVerifier struct {
	valid bool // when false, Verify always reports an error
}

// Verify returns nil when the mock was built with valid=true and an
// "invalid token" error otherwise, regardless of token.
func (v *MockJWTVerifier) Verify(token string) error {
	if v.valid {
		return nil
	}
	return fmt.Errorf("invalid token")
}
// equalSlices reports whether a and b have the same length and the same
// strings at every index. A nil slice and an empty slice compare equal.
func equalSlices(a, b []string) bool {
	same := len(a) == len(b)
	// Stop at the first differing element; the length check guards b[i].
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}
// createTokenOfSize returns a JWT-shaped token whose length grows with targetSize.
//
// For targetSize <= 1000 it returns baseToken unchanged, so small targets are
// not honored exactly. For larger targets it asks CreateLargeValidJWT for a
// claim of targetSize-230 random characters. CreateLargeValidJWT base64url-
// encodes the whole payload, claim included, so the resulting token is roughly
// a third longer than targetSize, not equal to it.
func createTokenOfSize(baseToken string, targetSize int) string {
	// Small targets: the caller's base token is used as-is.
	if targetSize <= 1000 {
		return baseToken
	}
	const (
		jwtOverhead  = 230 // rough allowance for header, fixed claims and signature
		minClaimSize = 10
	)
	claimSize := targetSize - jwtOverhead
	// Clamp to the same floor the fallback uses. The original `< 0` check let
	// values 0..9 through; neither form triggers while the threshold above is 1000.
	if claimSize < minClaimSize {
		claimSize = minClaimSize
	}
	return NewTestTokens().CreateLargeValidJWT(claimSize)
}
// TestTokens holds canned JWT fixture strings for tests. Build it with
// NewTestTokens; the zero value holds empty strings.
type TestTokens struct {
validJWT string // returned by CreateValidJWT; NewTestTokens sets a token with exp=3000000000
expiredJWT string // returned by CreateExpiredJWT; NewTestTokens sets a token with exp=1000000000
}
// NewTestTokens builds the fixture set. Both tokens share the same
// {"alg":"RS256","kid":"test-key-id"} header and a placeholder signature
// segment (base64url of "test-signature"). validJWT carries exp=3000000000
// (year 2065); expiredJWT carries exp=1000000000 (2001).
func NewTestTokens() *TestTokens {
	fixtures := new(TestTokens)
	fixtures.validJWT = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU"
	fixtures.expiredJWT = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU"
	return fixtures
}
// CreateValidJWT returns the stored non-expired fixture token (no new token is
// generated).
func (tt *TestTokens) CreateValidJWT() string {
	token := tt.validJWT
	return token
}
// TokenSet groups an access token, an ID token and a refresh token so fixture
// helpers (GetValidTokenSet, GetLargeTokenSet) can return them together.
type TokenSet struct {
AccessToken string
IDToken string
RefreshToken string
}
// GetValidTokenSet returns a fresh TokenSet whose access and ID tokens are both
// the valid fixture JWT, paired with the shared ValidRefreshToken.
func (tt *TestTokens) GetValidTokenSet() *TokenSet {
	set := new(TokenSet)
	set.AccessToken = tt.validJWT
	set.IDToken = tt.validJWT
	set.RefreshToken = ValidRefreshToken
	return set
}
// CreateIncompressibleToken returns a three-segment, dot-separated string whose
// middle segment is generateRandomString(size). Random content is meant to
// resist compression; the result is not a decodable JWT.
func (tt *TestTokens) CreateIncompressibleToken(size int) string {
	return strings.Join([]string{"incompressible", generateRandomString(size), "signature"}, ".")
}
// CreateUniqueValidJWT returns the valid fixture with "_"+suffix appended, so
// distinct suffixes yield distinct strings. The suffix lands inside the
// signature segment.
func (tt *TestTokens) CreateUniqueValidJWT(suffix string) string {
	return fmt.Sprintf("%s_%s", tt.validJWT, suffix)
}
// GetLargeTokenSet returns a TokenSet with independently generated
// incompressible access and ID tokens (2000 random characters each in the
// middle segment) and the shared ValidRefreshToken.
func (tt *TestTokens) GetLargeTokenSet() *TokenSet {
	const incompressibleSize = 2000
	access := tt.CreateIncompressibleToken(incompressibleSize)
	id := tt.CreateIncompressibleToken(incompressibleSize)
	return &TokenSet{AccessToken: access, IDToken: id, RefreshToken: ValidRefreshToken}
}
// CreateExpiredJWT returns the stored already-expired fixture token (no new
// token is generated).
func (tt *TestTokens) CreateExpiredJWT() string {
	token := tt.expiredJWT
	return token
}
// CreateLargeValidJWT builds an unsigned, JWT-shaped token whose payload carries
// the usual test claims plus a "large_claim" of claimSize random characters.
// Every segment is base64url-encoded without padding; the signature segment
// is just the encoding of "test-signature".
func (tt *TestTokens) CreateLargeValidJWT(claimSize int) string {
	claim := generateRandomString(claimSize)
	enc := base64.RawURLEncoding
	segments := []string{
		enc.EncodeToString([]byte(`{"alg":"RS256","kid":"test-key-id"}`)),
		enc.EncodeToString([]byte(`{"iss":"https://test-issuer.com","aud":"test-client-id","exp":3000000000,"sub":"test-subject","email":"test@example.com","large_claim":"` + claim + `"}`)),
		enc.EncodeToString([]byte("test-signature")),
	}
	return strings.Join(segments, ".")
}