From 79e9b164f99e0d34673928845a3a95441d0ede01 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Mon, 13 Oct 2025 10:43:35 +0100 Subject: [PATCH] release 0.7.9 (#78) * Speed improvements. After introduction of introspection the plugin became significantly slower. This commit introduces several optimizations to bring the speed back up. * Add relevant documentation and tests. --- .traefik.yml | 62 +++++ README.md | 12 +- cache_manager.go | 8 + .../AUTH0_AUDIENCE_GUIDE.md | 0 .../TEST_EXECUTION_GUIDE.md | 0 main.go | 1 + token_manager.go | 230 ++++++++++-------- token_type_detection_bench_test.go | 173 +++++++++++++ token_type_detection_test.go | 211 ++++++++++++++++ types.go | 1 + universal_cache_singleton.go | 18 +- 11 files changed, 611 insertions(+), 105 deletions(-) rename AUTH0_AUDIENCE_GUIDE.md => docs/AUTH0_AUDIENCE_GUIDE.md (100%) rename TEST_EXECUTION_GUIDE.md => docs/TEST_EXECUTION_GUIDE.md (100%) create mode 100644 token_type_detection_bench_test.go create mode 100644 token_type_detection_test.go diff --git a/.traefik.yml b/.traefik.yml index 6db077f..f077f2b 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -121,6 +121,12 @@ testData: - "https://*.example.com" corsAllowCredentials: true + # Cross-origin policies + permissionsPolicy: "geolocation=(), camera=(), microphone=()" + crossOriginEmbedderPolicy: "require-corp" + crossOriginOpenerPolicy: "same-origin" + crossOriginResourcePolicy: "same-origin" + # Custom headers customHeaders: X-Custom-Header: "production" @@ -1031,3 +1037,59 @@ configuration: Remove the X-Powered-By header to hide technology stack information. Default: true required: false + + permissionsPolicy: + type: string + description: | + Permissions-Policy header to control browser feature permissions. + This header allows you to control which features and APIs can be used. + + Examples: + - "geolocation=(), camera=(), microphone=()" (deny all) + - "geolocation=(self), camera=()" (allow geolocation for same origin only) + + Common directives: accelerometer, camera, geolocation, gyroscope, + magnetometer, microphone, payment, usb + required: false + + crossOriginEmbedderPolicy: + type: string + description: | + Cross-Origin-Embedder-Policy (COEP) header to prevent untrusted + resources from being loaded. + + Options: + - "require-corp": Resources must explicitly grant permission + - "credentialless": Load without credentials for cross-origin resources + - "unsafe-none": No restrictions (default) + + Required for certain browser features like SharedArrayBuffer. + required: false + + crossOriginOpenerPolicy: + type: string + description: | + Cross-Origin-Opener-Policy (COOP) header to isolate browsing context + from cross-origin windows. + + Options: + - "same-origin": Isolate from cross-origin documents + - "same-origin-allow-popups": Allow popups that don't set COOP + - "unsafe-none": No isolation (default) + + Helps prevent cross-origin attacks and Spectre-like vulnerabilities. + required: false + + crossOriginResourcePolicy: + type: string + description: | + Cross-Origin-Resource-Policy (CORP) header to control which origins + can load this resource. + + Options: + - "same-origin": Only same-origin requests can load the resource + - "same-site": Only same-site requests can load the resource + - "cross-origin": Any origin can load the resource (default) + + Prevents your resources from being embedded on other sites. + required: false diff --git a/README.md b/README.md index 0275cba..6cb633d 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ experimental: plugins: traefikoidc: moduleName: github.com/lukaszraczylo/traefikoidc - version: v0.2.1 # Use the latest version + version: v0.7.8 # Use the latest version ``` 2. Configure the middleware in your dynamic configuration (see examples below). @@ -301,7 +301,7 @@ spec: strictAudienceValidation: true ``` -For detailed Auth0 configuration including all three scenarios, troubleshooting, and security best practices, see **[AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md)**. +For detailed Auth0 configuration including all three scenarios, troubleshooting, and security best practices, see **[AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md)**. ## Security Headers Configuration @@ -421,6 +421,10 @@ securityHeaders: | `customHeaders` | Additional custom headers | `{}` | `{"X-Custom": "value"}` | | `disableServerHeader` | Remove Server header | `true` | `true`, `false` | | `disablePoweredByHeader` | Remove X-Powered-By header | `true` | `true`, `false` | +| `permissionsPolicy` | Permissions-Policy header | `` | `"geolocation=(), camera=(), microphone=()"` | +| `crossOriginEmbedderPolicy` | Cross-Origin-Embedder-Policy header | `` | `"require-corp"`, `"credentialless"`, `"unsafe-none"` | +| `crossOriginOpenerPolicy` | Cross-Origin-Opener-Policy header | `` | `"same-origin"`, `"same-origin-allow-popups"`, `"unsafe-none"` | +| `crossOriginResourcePolicy` | Cross-Origin-Resource-Policy header | `` | `"same-origin"`, `"same-site"`, `"cross-origin"` | ### CORS Wildcard Support @@ -855,7 +859,7 @@ spec: postLogoutRedirectURI: /logged-out-page # Must be in Auth0 Allowed Logout URLs ``` -**Note**: For detailed Auth0 audience configuration including opaque tokens and all security scenarios, see [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md). +**Note**: For detailed Auth0 audience configuration including opaque tokens and all security scenarios, see [AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md). ### Okta Configuration @@ -1029,7 +1033,7 @@ services: image: traefik:v3.2.1 command: - "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc" - - "--experimental.plugins.traefikoidc.version=v0.2.1" + - "--experimental.plugins.traefikoidc.version=v0.7.8" volumes: - /var/run/docker.sock:/var/run/docker.sock - ./traefik-config/traefik.yml:/etc/traefik/traefik.yml diff --git a/cache_manager.go b/cache_manager.go index 01e9b11..7a3738d 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -69,6 +69,14 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface { return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()} } +// GetSharedTokenTypeCache returns the shared token type cache +// for caching token type detection results to improve performance +func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface { + cm.mu.RLock() + defer cm.mu.RUnlock() + return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()} +} + // Close gracefully shuts down all cache components func (cm *CacheManager) Close() error { cm.mu.Lock() diff --git a/AUTH0_AUDIENCE_GUIDE.md b/docs/AUTH0_AUDIENCE_GUIDE.md similarity index 100% rename from AUTH0_AUDIENCE_GUIDE.md rename to docs/AUTH0_AUDIENCE_GUIDE.md diff --git a/TEST_EXECUTION_GUIDE.md b/docs/TEST_EXECUTION_GUIDE.md similarity index 100% rename from TEST_EXECUTION_GUIDE.md rename to docs/TEST_EXECUTION_GUIDE.md diff --git a/main.go b/main.go index 7c918c0..956cce3 100644 --- a/main.go +++ b/main.go @@ -153,6 +153,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name return config.PostLogoutRedirectURI }(), tokenBlacklist: cacheManager.GetSharedTokenBlacklist(), + tokenTypeCache: cacheManager.GetSharedTokenTypeCache(), // Cache for token type detection jwkCache: cacheManager.GetSharedJWKCache(), metadataCache: cacheManager.GetSharedMetadataCache(), introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results diff --git a/token_manager.go b/token_manager.go index 5867fc9..b20f7c4 100644 --- a/token_manager.go +++ b/token_manager.go @@ -158,6 +158,134 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa t.tokenCache.Set(token, claims, duration) } +// detectTokenType efficiently detects whether a token is an ID token or access token. +// It uses caching to avoid re-detection and optimizes the detection order for performance. +// Parameters: +// - jwt: The parsed JWT structure containing header and claims. +// - token: The raw token string for cache key generation. +// +// Returns: +// - true if the token is an ID token, false if it's an access token. +func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool { + // Use first 32 chars of token as cache key (sufficient for uniqueness) + cacheKey := token + if len(token) > 32 { + cacheKey = token[:32] + } + + // Check cache first + if t.tokenTypeCache != nil { + if cachedType, found := t.tokenTypeCache.Get(cacheKey); found { + if isIDToken, ok := cachedType.(bool); ok { + return isIDToken + } + } + } + + // Perform optimized detection + isIDToken := false + + // 1. Check 'nonce' claim first (most definitive for ID tokens - short circuit) + if nonce, ok := jwt.Claims["nonce"]; ok { + if _, ok := nonce.(string); ok { + isIDToken = true + if !t.suppressDiagnosticLogs { + t.safeLogDebugf("ID token detected via nonce claim") + } + // Cache and return immediately + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, true, 5*time.Minute) + } + return true + } + } + + // 2. Check 'typ' header for "at+jwt" (definitive for access tokens - short circuit) + if typ, ok := jwt.Header["typ"].(string); ok && typ == "at+jwt" { + // RFC 9068 compliant access token + if !t.suppressDiagnosticLogs { + t.safeLogDebugf("RFC 9068 access token detected (typ=at+jwt)") + } + // Cache and return immediately + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute) + } + return false + } + + // 3. Check 'token_use' claim (definitive if present - short circuit) + if tokenUse, ok := jwt.Claims["token_use"].(string); ok { + if tokenUse == "id" { + isIDToken = true + if !t.suppressDiagnosticLogs { + t.safeLogDebugf("ID token detected via token_use claim") + } + // Cache and return + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, true, 5*time.Minute) + } + return true + } else if tokenUse == "access" { + if !t.suppressDiagnosticLogs { + t.safeLogDebugf("Access token detected via token_use claim") + } + // Cache and return + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute) + } + return false + } + } + + // 4. Check 'scope' claim (strong indicator for access tokens) + if scope, ok := jwt.Claims["scope"]; ok { + if _, ok := scope.(string); ok { + if !t.suppressDiagnosticLogs { + t.safeLogDebugf("Access token detected via scope claim") + } + // Cache and return + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute) + } + return false + } + } + + // 5. Check if aud == clientID only (ID token pattern) + if aud, ok := jwt.Claims["aud"]; ok { + // Check string audience + if audStr, ok := aud.(string); ok && audStr == t.clientID { + isIDToken = true + } else if audArr, ok := aud.([]interface{}); ok { + // Check array audience - only treat as ID token if client_id is sole audience + if len(audArr) == 1 { + for _, v := range audArr { + if str, ok := v.(string); ok && str == t.clientID { + isIDToken = true + break + } + } + } + } + } + + // Cache the result + if t.tokenTypeCache != nil { + t.tokenTypeCache.Set(cacheKey, isIDToken, 5*time.Minute) + } + + // Log detection result in debug mode + if !t.suppressDiagnosticLogs { + if isIDToken { + t.safeLogDebugf("ID token detected via audience matching") + } else { + t.safeLogDebugf("Defaulting to access token") + } + } + + return isIDToken +} + // VerifyJWTSignatureAndClaims verifies JWT signature using provider's public keys and validates standard claims. // It retrieves the appropriate public key from the JWKS cache, verifies the token signature, // and validates standard OIDC claims like issuer, audience, and expiration. @@ -240,105 +368,8 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error t.safeLogDebugf("DIAGNOSTIC: Signature verification successful for kid=%s", kid) } - // Determine expected audience based on token type - // Per OIDC spec: ID tokens MUST have aud=client_id - // Access tokens can have custom audience values (e.g., Auth0 API identifiers) - - // Token type detection strategy (RFC 9068 + OIDC Core 1.0): - // 1. Check 'typ' header claim (RFC 9068) → "at+jwt" = ACCESS_TOKEN, "JWT" = could be either - // 2. Check explicit token type claims (token_use, token_type) if present - // 3. Check 'scope' claim → ACCESS_TOKEN (use configured audience) - // 4. Check 'nonce' claim → ID_TOKEN (use client_id, per OIDC spec) - // 5. Check if aud == client_id only → ID_TOKEN (use client_id) - // 6. Else → ACCESS_TOKEN with custom audience (use configured audience) - - isIDToken := false - isAccessToken := false - - // Step 1: Check typ header for explicit type (RFC 9068) - if typ, ok := jwt.Header["typ"].(string); ok { - if typ == "at+jwt" { - // RFC 9068 compliant access token - isAccessToken = true - if !t.suppressDiagnosticLogs { - t.safeLogDebugf("RFC 9068 access token detected (typ=at+jwt)") - } - } else if typ == "JWT" { - // Generic JWT, need further checks - if !t.suppressDiagnosticLogs { - t.safeLogDebugf("Generic JWT detected (typ=JWT), checking claims") - } - } - } - - // Step 2: Check explicit token type claims (if not already determined) - if !isAccessToken && !isIDToken { - // Check for token_use claim (used by some providers like AWS Cognito) - if tokenUse, ok := jwt.Claims["token_use"].(string); ok { - if tokenUse == "access" { - isAccessToken = true - } else if tokenUse == "id" { - isIDToken = true - } - } - - // Check for token_type claim - if !isAccessToken && !isIDToken { - if tokenType, ok := jwt.Claims["token_type"].(string); ok { - if tokenType == "access_token" || tokenType == "Bearer" { - isAccessToken = true - } else if tokenType == "id_token" { - isIDToken = true - } - } - } - } - - // Step 3: Check scope claim (access tokens have this) - if !isAccessToken && !isIDToken { - if scope, ok := jwt.Claims["scope"]; ok { - if _, ok := scope.(string); ok { - isAccessToken = true - } - } - } - - // Step 4: Check nonce claim (ID tokens have this per OIDC spec for replay protection) - if !isAccessToken && !isIDToken { - if nonce, ok := jwt.Claims["nonce"]; ok { - if _, ok := nonce.(string); ok { - isIDToken = true // Nonce indicates ID token - } - } - } - - // Step 5: If no scope and no nonce, check if aud matches client_id (indicates ID token) - if !isAccessToken && !isIDToken { - if aud, ok := jwt.Claims["aud"]; ok { - // Check string audience - if audStr, ok := aud.(string); ok && audStr == t.clientID { - isIDToken = true - } - // Check array audience - if audArr, ok := aud.([]interface{}); ok { - for _, v := range audArr { - if str, ok := v.(string); ok && str == t.clientID { - // Only treat as ID token if it's the sole audience - // Access tokens can also contain client_id in array - if len(audArr) == 1 { - isIDToken = true - } - break - } - } - } - } - } - - // Step 6: Default to access token if still undetermined - if !isIDToken { - isAccessToken = true - } + // Detect token type (cached for performance) + isIDToken := t.detectTokenType(jwt, token) // Determine expected audience expectedAudience := t.audience // Default to configured audience @@ -348,7 +379,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error t.safeLogDebugf("ID token detected, validating with client_id: %s", expectedAudience) } } else { - // Access token or ambiguous - use configured audience if !t.suppressDiagnosticLogs { t.safeLogDebugf("Access token detected, validating with audience: %s", expectedAudience) } diff --git a/token_type_detection_bench_test.go b/token_type_detection_bench_test.go new file mode 100644 index 0000000..d61feb1 --- /dev/null +++ b/token_type_detection_bench_test.go @@ -0,0 +1,173 @@ +package traefikoidc + +import ( + "testing" + "time" +) + +func BenchmarkDetectTokenType(b *testing.B) { + tr := &TraefikOidc{ + clientID: "test-client-id", + suppressDiagnosticLogs: true, + tokenTypeCache: NewTestCache(), + } + + // Create various JWT test cases + jwtWithNonce := &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "nonce": "test-nonce", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + }, + } + + jwtWithScope := &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "scope": "openid profile email", + "aud": "some-api", + "exp": time.Now().Add(1 * time.Hour).Unix(), + }, + } + + jwtComplexDetection := &JWT{ + Header: map[string]interface{}{"alg": "RS256", "typ": "JWT"}, + Claims: map[string]interface{}{ + "aud": []interface{}{"test-client-id", "another-aud"}, + "exp": time.Now().Add(1 * time.Hour).Unix(), + "sub": "user123", + "token_type": "Bearer", + "custom_claim": "value", + }, + } + + testCases := []struct { + name string + jwt *JWT + token string + }{ + {"WithNonce", jwtWithNonce, "token-with-nonce-for-benchmark-testing-12345678901234567890"}, + {"WithScope", jwtWithScope, "token-with-scope-for-benchmark-testing-12345678901234567890"}, + {"ComplexDetection", jwtComplexDetection, "token-complex-for-benchmark-testing-12345678901234567890"}, + } + + for _, tc := range testCases { + b.Run(tc.name+"_FirstCall", func(b *testing.B) { + // Benchmark first call (uncached) + for i := 0; i < b.N; i++ { + // Clear cache before each iteration + tr.tokenTypeCache.Clear() + _ = tr.detectTokenType(tc.jwt, tc.token) + } + }) + + b.Run(tc.name+"_Cached", func(b *testing.B) { + // Prime the cache + _ = tr.detectTokenType(tc.jwt, tc.token) + + // Benchmark cached calls + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tr.detectTokenType(tc.jwt, tc.token) + } + }) + } +} + +// Benchmark comparison with the old implementation logic +func BenchmarkOldDetectionLogic(b *testing.B) { + clientID := "test-client-id" + + jwt := &JWT{ + Header: map[string]interface{}{"alg": "RS256", "typ": "JWT"}, + Claims: map[string]interface{}{ + "aud": []interface{}{"test-client-id", "another-aud"}, + "exp": time.Now().Add(1 * time.Hour).Unix(), + "sub": "user123", + "token_type": "Bearer", + "custom_claim": "value", + }, + } + + b.Run("OldLogic", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Simulate the old detection logic (all 6 sequential checks) + isIDToken := false + isAccessToken := false + + // Step 1: Check typ header + if typ, ok := jwt.Header["typ"].(string); ok { + if typ == "at+jwt" { + isAccessToken = true + } + } + + // Step 2: Check token_use claim + if !isAccessToken && !isIDToken { + if tokenUse, ok := jwt.Claims["token_use"].(string); ok { + if tokenUse == "access" { + isAccessToken = true + } else if tokenUse == "id" { + isIDToken = true + } + } + } + + // Step 3: Check token_type claim + if !isAccessToken && !isIDToken { + if tokenType, ok := jwt.Claims["token_type"].(string); ok { + if tokenType == "access_token" || tokenType == "Bearer" { + isAccessToken = true + } else if tokenType == "id_token" { + isIDToken = true + } + } + } + + // Step 4: Check scope claim + if !isAccessToken && !isIDToken { + if scope, ok := jwt.Claims["scope"]; ok { + if _, ok := scope.(string); ok { + isAccessToken = true + } + } + } + + // Step 5: Check nonce claim + if !isAccessToken && !isIDToken { + if nonce, ok := jwt.Claims["nonce"]; ok { + if _, ok := nonce.(string); ok { + isIDToken = true + } + } + } + + // Step 6: Check audience + if !isAccessToken && !isIDToken { + if aud, ok := jwt.Claims["aud"]; ok { + if audStr, ok := aud.(string); ok && audStr == clientID { + isIDToken = true + } + if audArr, ok := aud.([]interface{}); ok { + for _, v := range audArr { + if str, ok := v.(string); ok && str == clientID { + if len(audArr) == 1 { + isIDToken = true + } + break + } + } + } + } + } + + // Step 7: Default to access token + if !isIDToken { + isAccessToken = true + } + + _ = isAccessToken + } + }) +} diff --git a/token_type_detection_test.go b/token_type_detection_test.go new file mode 100644 index 0000000..4831707 --- /dev/null +++ b/token_type_detection_test.go @@ -0,0 +1,211 @@ +package traefikoidc + +import ( + "testing" + "time" +) + +func TestDetectTokenType(t *testing.T) { + // Create a test instance with mock cache + tr := &TraefikOidc{ + clientID: "test-client-id", + suppressDiagnosticLogs: true, + tokenTypeCache: NewTestCache(), + } + + testCases := []struct { + name string + jwt *JWT + token string + expectedID bool + description string + }{ + { + name: "ID token with nonce", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "nonce": "test-nonce", + "aud": "test-client-id", + }, + }, + token: "test-token-with-nonce", + expectedID: true, + description: "Should detect ID token via nonce claim", + }, + { + name: "RFC 9068 access token", + jwt: &JWT{ + Header: map[string]interface{}{ + "alg": "RS256", + "typ": "at+jwt", + }, + Claims: map[string]interface{}{ + "scope": "openid profile", + }, + }, + token: "test-access-token-rfc9068", + expectedID: false, + description: "Should detect access token via typ=at+jwt header", + }, + { + name: "Token with token_use=id", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "token_use": "id", + "aud": "test-client-id", + }, + }, + token: "test-token-use-id", + expectedID: true, + description: "Should detect ID token via token_use claim", + }, + { + name: "Token with token_use=access", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "token_use": "access", + "scope": "read write", + }, + }, + token: "test-token-use-access", + expectedID: false, + description: "Should detect access token via token_use claim", + }, + { + name: "Access token with scope", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "scope": "openid profile email", + "aud": "some-api-audience", + }, + }, + token: "test-access-token-with-scope", + expectedID: false, + description: "Should detect access token via scope claim", + }, + { + name: "ID token with client_id audience", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "aud": "test-client-id", + "sub": "user123", + }, + }, + token: "test-id-token-client-aud", + expectedID: true, + description: "Should detect ID token via audience matching client_id", + }, + { + name: "Default to access token", + jwt: &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "aud": "different-audience", + "sub": "user123", + }, + }, + token: "test-default-access-token", + expectedID: false, + description: "Should default to access token when no clear indicators", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // First call - should not be cached + result := tr.detectTokenType(tc.jwt, tc.token) + if result != tc.expectedID { + t.Errorf("%s: expected isIDToken=%v, got %v", tc.description, tc.expectedID, result) + } + + // Second call - should be cached + result2 := tr.detectTokenType(tc.jwt, tc.token) + if result2 != tc.expectedID { + t.Errorf("%s (cached): expected isIDToken=%v, got %v", tc.description, tc.expectedID, result2) + } + }) + } +} + +func TestDetectTokenTypeCaching(t *testing.T) { + cache := NewTestCache() + tr := &TraefikOidc{ + clientID: "test-client-id", + suppressDiagnosticLogs: true, + tokenTypeCache: cache, + } + + jwt := &JWT{ + Header: map[string]interface{}{"alg": "RS256"}, + Claims: map[string]interface{}{ + "nonce": "test-nonce", + }, + } + token := "test-token-for-caching-with-enough-characters-for-key" + cacheKey := token + if len(token) > 32 { + cacheKey = token[:32] // First 32 chars + } + + // First call - should cache + result := tr.detectTokenType(jwt, token) + if !result { + t.Error("Expected ID token detection via nonce") + } + + // Check cache was populated + if cached, found := cache.Get(cacheKey); !found { + t.Error("Expected token type to be cached") + } else if cachedBool, ok := cached.(bool); !ok || !cachedBool { + t.Error("Expected cached value to be true (ID token)") + } + + // Modify JWT to have different detection (but use same token for cache key) + jwt.Claims = map[string]interface{}{ + "scope": "openid profile", // This would normally make it an access token + } + + // Second call with modified JWT - should still return cached value + result2 := tr.detectTokenType(jwt, token) + if !result2 { + t.Error("Expected cached ID token result, ignoring modified JWT") + } +} + +// TestCache is a simple in-memory cache for testing +type TestCache struct { + data map[string]interface{} +} + +func NewTestCache() *TestCache { + return &TestCache{ + data: make(map[string]interface{}), + } +} + +func (c *TestCache) Set(key string, value interface{}, ttl time.Duration) { + c.data[key] = value +} + +func (c *TestCache) Get(key string) (interface{}, bool) { + val, ok := c.data[key] + return val, ok +} + +func (c *TestCache) Delete(key string) { + delete(c.data, key) +} + +func (c *TestCache) SetMaxSize(size int) {} +func (c *TestCache) Size() int { return len(c.data) } +func (c *TestCache) Clear() { c.data = make(map[string]interface{}) } +func (c *TestCache) Cleanup() {} +func (c *TestCache) Close() {} +func (c *TestCache) GetStats() map[string]interface{} { + return map[string]interface{}{"size": len(c.data)} +} diff --git a/types.go b/types.go index f270d6c..a78b7d4 100644 --- a/types.go +++ b/types.go @@ -73,6 +73,7 @@ type TraefikOidc struct { initComplete chan struct{} limiter *rate.Limiter tokenBlacklist CacheInterface + tokenTypeCache CacheInterface // Cache for token type detection results headerTemplates map[string]*template.Template sessionManager *SessionManager tokenCleanupStopChan chan struct{} diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go index e2ce474..cf48cb4 100644 --- a/universal_cache_singleton.go +++ b/universal_cache_singleton.go @@ -13,6 +13,7 @@ type UniversalCacheManager struct { jwkCache *UniversalCache sessionCache *UniversalCache introspectionCache *UniversalCache // OAuth 2.0 Token Introspection cache (RFC 7662) + tokenTypeCache *UniversalCache // Cache for token type detection results mu sync.RWMutex logger *Logger } @@ -94,6 +95,14 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager { DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently) Logger: logger, }) + + // Initialize token type cache for performance optimization + universalCacheManager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{ + Type: CacheTypeToken, // Use token cache type for token type detection + MaxSize: 2000, // Cache up to 2000 token type detections + DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection + Logger: logger, + }) }) return universalCacheManager @@ -141,13 +150,20 @@ func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache { return m.introspectionCache } +// GetTokenTypeCache returns the token type detection cache +func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache { + m.mu.RLock() + defer m.mu.RUnlock() + return m.tokenTypeCache +} + // Close shuts down all caches func (m *UniversalCacheManager) Close() error { m.mu.Lock() defer m.mu.Unlock() for _, cache := range []*UniversalCache{ - m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, + m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, } { if cache != nil { cache.Close()