From 7d204113eab0f09eb16b17f525494027e259301c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 25 Feb 2025 12:53:52 +0000 Subject: [PATCH] Cleanup the codebase, DRY and abstract functions, increase the test coverage. --- autocleanup.go | 2 +- autocleanup_test.go | 2 +- jwk.go | 106 ++++------------ jwt.go | 278 +++++++++++++---------------------------- main.go | 166 +++++++++++++----------- main_test.go | 19 ++- metadata_cache.go | 3 +- metadata_cache_test.go | 2 +- 8 files changed, 225 insertions(+), 353 deletions(-) diff --git a/autocleanup.go b/autocleanup.go index 523ed2c..9eb9b23 100644 --- a/autocleanup.go +++ b/autocleanup.go @@ -15,4 +15,4 @@ func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup fu return } } -} \ No newline at end of file +} diff --git a/autocleanup_test.go b/autocleanup_test.go index e7266a3..3f5e7f7 100644 --- a/autocleanup_test.go +++ b/autocleanup_test.go @@ -19,4 +19,4 @@ func TestAutoCleanupRoutine(t *testing.T) { if atomic.LoadInt32(&counter) < 3 { t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter) } -} \ No newline at end of file +} diff --git a/jwk.go b/jwk.go index fc8d11b..de5aa22 100644 --- a/jwk.go +++ b/jwk.go @@ -1,92 +1,51 @@ package traefikoidc import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" - "math/big" - "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" + "math/big" "net/http" "sync" "time" ) -// JWK represents a JSON Web Key as defined in RFC 7517. -// It contains the cryptographic key information used for token verification. type JWK struct { - // Kty is the key type (e.g., "RSA", "EC") Kty string `json:"kty"` - - // Kid is the unique key identifier Kid string `json:"kid"` - - // Use specifies the intended use of the key (e.g., "sig" for signature) Use string `json:"use"` - - // N is the modulus for RSA keys - N string `json:"n"` - - // E is the exponent for RSA keys - E string `json:"e"` - - // Alg is the algorithm intended for use with the key + N string `json:"n"` + E string `json:"e"` Alg string `json:"alg"` - - // Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521") Crv string `json:"crv"` - - // X is the x-coordinate for EC keys - X string `json:"x"` - - // Y is the y-coordinate for EC keys - Y string `json:"y"` + X string `json:"x"` + Y string `json:"y"` } -// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint. -// OIDC providers typically expose multiple keys to support key rotation. type JWKSet struct { - // Keys is the array of JSON Web Keys Keys []JWK `json:"keys"` } -// JWKCache provides a thread-safe caching mechanism for JWK sets. -// It caches the keys for a configurable duration to reduce load on the OIDC provider -// while ensuring keys are refreshed periodically to handle key rotation. type JWKCache struct { - // jwks holds the cached set of JSON Web Keys - jwks *JWKSet - - // expiresAt is the timestamp when the cached keys should be refreshed + jwks *JWKSet expiresAt time.Time - - // mutex protects concurrent access to the cache - mutex sync.RWMutex + mutex sync.RWMutex + // CacheLifetime is configurable to determine how long the JWKS is cached. + CacheLifetime time.Duration } -// JWKCacheInterface defines the interface for JWK caching operations. -// This interface allows for different caching implementations while -// maintaining consistent behavior in the token verification process. type JWKCacheInterface interface { - GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) - Cleanup() // Add Cleanup method to the interface + GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) + Cleanup() } -// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it -// from the OIDC provider. It implements a thread-safe double-checked locking -// pattern to prevent multiple simultaneous fetches of the same keys. -// Parameters: -// - jwksURL: The URL of the JWKS endpoint -// - httpClient: The HTTP client to use for fetching keys -// -// Returns: -// - The JSON Web Key Set -// - An error if the keys cannot be retrieved or parsed -func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { +func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { c.mutex.RLock() if c.jwks != nil && time.Now().Before(c.expiresAt) { defer c.mutex.RUnlock() @@ -96,23 +55,25 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er c.mutex.Lock() defer c.mutex.Unlock() - if c.jwks != nil && time.Now().Before(c.expiresAt) { return c.jwks, nil } - jwks, err := fetchJWKS(jwksURL, httpClient) + jwks, err := fetchJWKS(ctx, jwksURL, httpClient) if err != nil { return nil, err } c.jwks = jwks - c.expiresAt = time.Now().Add(1 * time.Hour) + lifetime := c.CacheLifetime + if lifetime == 0 { + lifetime = 1 * time.Hour + } + c.expiresAt = time.Now().Add(lifetime) return jwks, nil } -// Cleanup removes expired JWKs from the cache. func (c *JWKCache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() @@ -123,17 +84,14 @@ func (c *JWKCache) Cleanup() { } } -// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint. -// It handles HTTP communication and JSON parsing of the response. -// Parameters: -// - jwksURL: The URL of the JWKS endpoint -// - httpClient: The HTTP client to use for the request -// -// Returns: -// - The parsed JSON Web Key Set -// - An error if the request fails or the response is invalid -func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { - resp, err := httpClient.Get(jwksURL) +func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { + // Create a request with context to enforce timeout + req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create JWKS request: %w", err) + } + + resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch JWKS: %w", err) } @@ -151,9 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { return &jwks, nil } -// jwkToPEM converts a JSON Web Key to PEM format for use with standard -// cryptographic functions. It supports both RSA and EC keys, delegating -// to the appropriate converter based on the key type. func jwkToPEM(jwk *JWK) ([]byte, error) { converter, ok := jwkConverters[jwk.Kty] if !ok { @@ -169,9 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{ "EC": ecJWKToPEM, } -// rsaJWKToPEM converts an RSA JSON Web Key to PEM format. -// It handles base64url decoding of the modulus and exponent, -// constructs an RSA public key, and encodes it in PEM format. func rsaJWKToPEM(jwk *JWK) ([]byte, error) { nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { @@ -203,10 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) { return pubKeyPEM, nil } -// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format. -// It supports the P-256, P-384, and P-521 curves as defined in the -// OIDC specification, decoding the x and y coordinates and encoding -// the resulting public key in PEM format. func ecJWKToPEM(jwk *JWK) ([]byte, error) { xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) if err != nil { diff --git a/jwt.go b/jwt.go index 1fa430b..7a03b8c 100644 --- a/jwt.go +++ b/jwt.go @@ -4,44 +4,41 @@ import ( "crypto" "crypto/ecdsa" "crypto/rsa" - "math/big" - "strings" - "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" - + "math/big" + "strings" + "sync" "time" ) +var replayCacheMu sync.Mutex +var replayCache = make(map[string]time.Time) + +func cleanupReplayCache() { + now := time.Now() + for token, expiry := range replayCache { + if expiry.Before(now) { + delete(replayCache, token) + } + } +} + +// ClockSkewTolerance is configurable to adjust time-based validations. +var ClockSkewTolerance = 2 * time.Minute + // JWT represents a JSON Web Token as defined in RFC 7519. -// It contains the three parts of a JWT: header, claims (payload), -// and signature, along with the original token string. type JWT struct { - // Header contains the token metadata (algorithm, key ID, etc.) - Header map[string]interface{} - - // Claims contains the token claims (subject, expiration, etc.) - Claims map[string]interface{} - - // Signature contains the raw signature bytes + Header map[string]interface{} + Claims map[string]interface{} Signature []byte - - // Token is the original JWT string - Token string + Token string } // parseJWT parses a JWT token string into a JWT struct. -// It validates the token format and decodes the three parts -// (header, claims, signature) using base64url decoding. -// Parameters: -// - tokenString: The raw JWT token string -// -// Returns: -// - A parsed JWT struct -// - An error if the token format is invalid or parsing fails func parseJWT(tokenString string) (*JWT, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -52,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) { Token: tokenString, } - // Decode and unmarshal the header headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err) @@ -61,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) { return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err) } - // Decode and unmarshal the claims claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err) @@ -70,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) { return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err) } - // Decode the signature signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err) @@ -81,28 +75,13 @@ func parseJWT(tokenString string) (*JWT, error) { } // Verify validates the standard JWT claims as defined in RFC 7519. -// It checks: -// - issuer (iss) matches the expected issuer URL -// - audience (aud) includes the client ID -// - expiration time (exp) is in the future (with clock skew tolerance) -// - issued at time (iat) is in the past (with clock skew tolerance) -// - not before time (nbf) is in the past (with clock skew tolerance) -// - subject (sub) is present and not empty -// - algorithm matches expected value to prevent algorithm switching attacks -// -// Returns an error if any validation fails. +// Verify validates the standard JWT claims as defined in RFC 7519. func (j *JWT) Verify(issuerURL, clientID string) error { - // Debug logging of validation parameters - fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID) - // Debug logging of token header - fmt.Printf("Token header: %+v\n", j.Header) - // Validate algorithm to prevent algorithm switching attacks alg, ok := j.Header["alg"].(string) if !ok { return fmt.Errorf("missing 'alg' header") } - // List of supported algorithms - should match those in verifySignature supportedAlgs := map[string]bool{ "RS256": true, "RS384": true, "RS512": true, "PS256": true, "PS384": true, "PS512": true, @@ -114,9 +93,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error { claims := j.Claims - // Debug logging of all claims - fmt.Printf("Token claims: %+v\n", claims) - iss, ok := claims["iss"].(string) if !ok { return fmt.Errorf("missing 'iss' claim") @@ -149,17 +125,36 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return err } - // Validate nbf (not before) claim if present if nbf, ok := claims["nbf"].(float64); ok { if err := verifyNotBefore(nbf); err != nil { return err } } - // Validate jti (JWT ID) claim if present + // Implement replay protection by checking the jti (JWT ID) if jti, ok := claims["jti"].(string); ok { - // Could add replay detection here if needed - _ = jti + // Skip replay detection for tokens that are being verified from the cache + if j.Token == "" { + // This is a parsed JWT without the original token string, + // which means it's likely from a cached token verification + return nil + } + + replayCacheMu.Lock() + cleanupReplayCache() + if _, exists := replayCache[jti]; exists { + replayCacheMu.Unlock() + return fmt.Errorf("token replay detected") + } + expFloat, ok := claims["exp"].(float64) + var expTime time.Time + if ok { + expTime = time.Unix(int64(expFloat), 0) + } else { + expTime = time.Now().Add(10 * time.Minute) + } + replayCache[jti] = expTime + replayCacheMu.Unlock() } sub, ok := claims["sub"].(string) @@ -169,20 +164,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } - -// verifyAudience validates the token's audience claim. -// The audience can be either a single string or an array of strings. -// For array audiences, the expected audience must match any one value. -// Parameters: -// - tokenAudience: The audience claim from the token -// - expectedAudience: The expected audience value -// -// Returns an error if validation fails. func verifyAudience(tokenAudience interface{}, expectedAudience string) error { - // Debug logging - fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n", - tokenAudience, expectedAudience) - switch aud := tokenAudience.(type) { case string: if aud != expectedAudience { @@ -205,165 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error { return nil } -// verifyIssuer validates the token's issuer claim. -// The issuer URL must exactly match the expected issuer. -// Parameters: -// - tokenIssuer: The issuer claim from the token -// - expectedIssuer: The expected issuer URL -// -// Returns an error if validation fails. func verifyIssuer(tokenIssuer, expectedIssuer string) error { - // Debug logging - fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n", - tokenIssuer, expectedIssuer) - if tokenIssuer != expectedIssuer { - return fmt.Errorf("invalid issuer (token: %s, expected: %s)", - tokenIssuer, expectedIssuer) + return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer) } return nil } -// Clock skew tolerance for time-based validations -const clockSkewTolerance = 2 * time.Minute +// verifyTimeConstraint is a generic function to verify time-based claims +func verifyTimeConstraint(unixTime float64, claimName string, future bool) error { + claimTime := time.Unix(int64(unixTime), 0) + now := time.Now().Truncate(time.Second) + + // For expiration (future=true), we add skew to now (making now later) + // For iat/nbf (future=false), we subtract skew from now (making now earlier) + skewDirection := 1 + if !future { + skewDirection = -1 + } + skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance) + + if claimTime.Equal(now) { + return nil + } + + // For expiration: if skewedNow (later) is after expiration, token expired + // For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid + if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) { + var reason string + if future { + reason = "has expired" + } else { + if claimName == "iat" { + reason = "used before issued" + } else { + reason = "not yet valid" + } + } + return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC()) + } + + return nil +} -// verifyExpiration checks if the token's expiration time has passed. -// The expiration time is compared against the current time with clock skew tolerance. -// Parameters: -// - expiration: The expiration timestamp from the token -// -// Returns an error if the token has expired. func verifyExpiration(expiration float64) error { - expirationTime := time.Unix(int64(expiration), 0) - // Truncate current time to seconds for consistent comparison - now := time.Now().Truncate(time.Second) - skewedNow := now.Add(clockSkewTolerance) - - // Debug logging - fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", - expirationTime.UTC(), - now.UTC(), - skewedNow.UTC(), - clockSkewTolerance) - - // Allow tokens that expire exactly now - if expirationTime.Equal(now) { - return nil - } - - if skewedNow.After(expirationTime) { - return fmt.Errorf("token has expired (exp: %v, now: %v)", - expirationTime.UTC(), now.UTC()) - } - return nil + return verifyTimeConstraint(expiration, "exp", true) } -// verifyIssuedAt validates the token's issued-at time. -// Ensures the token wasn't issued in the future, accounting for clock skew. -// Parameters: -// - issuedAt: The issued-at timestamp from the token -// -// Returns an error if the token was issued in the future. func verifyIssuedAt(issuedAt float64) error { - issuedAtTime := time.Unix(int64(issuedAt), 0) - // Truncate current time to seconds for consistent comparison - now := time.Now().Truncate(time.Second) - skewedNow := now.Add(-clockSkewTolerance) - - // Debug logging - fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", - issuedAtTime.UTC(), - now.UTC(), - skewedNow.UTC(), - clockSkewTolerance) - - // Allow tokens issued in the same second as current time - if issuedAtTime.Equal(now) { - return nil - } - - if skewedNow.Before(issuedAtTime) { - return fmt.Errorf("token used before issued (iat: %v, now: %v)", - issuedAtTime.UTC(), now.UTC()) - } - return nil + return verifyTimeConstraint(issuedAt, "iat", false) } -// verifyNotBefore validates the token's not-before time if present. -// Ensures the token is not used before its valid time period, accounting for clock skew. -// Parameters: -// - notBefore: The not-before timestamp from the token -// -// Returns an error if the token is not yet valid. func verifyNotBefore(notBefore float64) error { - notBeforeTime := time.Unix(int64(notBefore), 0) - // Truncate current time to seconds for consistent comparison - now := time.Now().Truncate(time.Second) - skewedNow := now.Add(-clockSkewTolerance) - - // Debug logging - fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", - notBeforeTime.UTC(), - now.UTC(), - skewedNow.UTC(), - clockSkewTolerance) - - // Allow tokens that become valid exactly now - if notBeforeTime.Equal(now) { - return nil - } - - if skewedNow.Before(notBeforeTime) { - return fmt.Errorf("token not yet valid (nbf: %v, now: %v)", - notBeforeTime.UTC(), now.UTC()) - } - return nil + return verifyTimeConstraint(notBefore, "nbf", false) } -// verifySignature validates the token's cryptographic signature. -// Supports multiple signature algorithms: -// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5) -// - RSA-PSS: PS256, PS384, PS512 -// - ECDSA: ES256, ES384, ES512 -// -// Parameters: -// - tokenString: The complete JWT token string -// - publicKeyPEM: The PEM-encoded public key for verification -// - alg: The signature algorithm identifier -// -// Returns an error if signature verification fails. func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { - // Debug logging - fmt.Printf("Verifying signature with algorithm: %s\n", alg) - - // Split the token into its three parts parts := strings.Split(tokenString, ".") if len(parts) != 3 { return fmt.Errorf("invalid token format") } signedContent := parts[0] + "." + parts[1] - - // Decode the signature from the token signature, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } - - // Decode the PEM-encoded public key block, _ := pem.Decode(publicKeyPEM) if block == nil { return fmt.Errorf("failed to parse PEM block containing the public key") } - - // Parse the public key pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return fmt.Errorf("failed to parse public key: %w", err) } - - // Determine the hash function to use based on the algorithm var hashFunc crypto.Hash - switch alg { case "RS256", "PS256", "ES256": hashFunc = crypto.SHA256 @@ -374,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error default: return fmt.Errorf("unsupported algorithm: %s", alg) } - - // Hash the signed content h := hashFunc.New() h.Write([]byte(signedContent)) hashed := h.Sum(nil) - - // Verify the signature based on the key type and algorithm switch pubKey := pubKey.(type) { case *rsa.PublicKey: if strings.HasPrefix(alg, "RS") { - // RSA PKCS#1 v1.5 signature return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature) } else if strings.HasPrefix(alg, "PS") { - // RSA PSS signature return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil) } else { return fmt.Errorf("unexpected key type for algorithm %s", alg) } case *ecdsa.PublicKey: if strings.HasPrefix(alg, "ES") { - // ECDSA signature var r, s big.Int sigLen := len(signature) if sigLen%2 != 0 { diff --git a/main.go b/main.go index 540fbd1..56f2d8f 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,40 @@ import ( "golang.org/x/time/rate" ) +// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC +func createDefaultHTTPClient() *http.Client { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: 15 * time.Second, // Reduced timeout + KeepAlive: 15 * time.Second, // Reduced keepalive + } + return dialer.DialContext(ctx, network, addr) + }, + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s + ExpectContinueTimeout: 0, + MaxIdleConns: 30, // Reduced from 100 + MaxIdleConnsPerHost: 10, // Reduced from 100 + IdleConnTimeout: 30 * time.Second, // Reduced from 90s + DisableKeepAlives: false, // Enable connection reuse + MaxConnsPerHost: 50, // Limit max connections + } + + return &http.Client{ + Timeout: time.Second * 15, // Reduced timeout + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Always follow redirects for OIDC endpoints + if len(via) >= 50 { + return fmt.Errorf("stopped after 50 redirects") + } + return nil + }, + } +} + const ConstSessionTimeout = 86400 // Session timeout in seconds // TokenVerifier interface for token verification @@ -82,11 +116,40 @@ var defaultExcludedURLs = map[string]struct{}{ "/favicon": {}, } -// VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { + // Check cache first + if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 { + t.logger.Debugf("Token found in cache with valid claims; skipping verification") + return nil + } + t.logger.Debugf("Verifying token") - // Rate limiting + // Perform pre-verification checks + if err := t.performPreVerificationChecks(token); err != nil { + return err + } + + // Parse the JWT + jwt, err := parseJWT(token) + if err != nil { + return fmt.Errorf("failed to parse JWT: %w", err) + } + + // Verify JWT signature and standard claims + if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil { + return err + } + + // Cache the verified token + t.cacheVerifiedToken(token, jwt.Claims) + + return nil +} + +// performPreVerificationChecks performs rate limiting and blacklist checks +func (t *TraefikOidc) performPreVerificationChecks(token string) error { + // Enforce rate limiting if !t.limiter.Allow() { return fmt.Errorf("rate limit exceeded") } @@ -96,30 +159,15 @@ func (t *TraefikOidc) VerifyToken(token string) error { return fmt.Errorf("token is blacklisted") } - // Check if token is cached - if _, exists := t.tokenCache.Get(token); exists { - t.logger.Debugf("Token is valid and cached") - return nil // Token is valid and cached - } + return nil +} - // Parse the JWT - jwt, err := parseJWT(token) - if err != nil { - return fmt.Errorf("failed to parse JWT: %w", err) - } - - // Verify JWT signature and claims - if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil { - return err - } - - // Cache the token until it expires - expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) +// cacheVerifiedToken caches a verified token until its expiration time +func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) { + expirationTime := time.Unix(int64(claims["exp"].(float64)), 0) now := time.Now() duration := expirationTime.Sub(now) - t.tokenCache.Set(token, jwt.Claims, duration) - - return nil + t.tokenCache.Set(token, claims, duration) } // VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims @@ -127,7 +175,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error t.logger.Debugf("Verifying JWT signature and claims") // Get JWKS - jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) + jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient) if err != nil { return fmt.Errorf("failed to get JWKS: %w", err) } @@ -187,7 +235,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // Initialize logger logger := NewLogger(config.LogLevel) - // Ensure key meets minimum length requirement if len(config.SessionEncryptionKey) < minEncryptionKeyLength { if runtime.Compiler == "yaegi" { @@ -198,42 +245,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) } } - // Setup HTTP client - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: 15 * time.Second, // Reduced timeout - KeepAlive: 15 * time.Second, // Reduced keepalive - } - return dialer.DialContext(ctx, network, addr) - }, - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s - ExpectContinueTimeout: 0, - MaxIdleConns: 30, // Reduced from 100 - MaxIdleConnsPerHost: 10, // Reduced from 100 - IdleConnTimeout: 30 * time.Second, // Reduced from 90s - DisableKeepAlives: false, // Enable connection reuse - MaxConnsPerHost: 50, // Limit max connections - } - var httpClient *http.Client if config.HTTPClient != nil { httpClient = config.HTTPClient } else { - httpClient = &http.Client{ - Timeout: time.Second * 15, // Reduced timeout - Transport: transport, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Always follow redirects for OIDC endpoints - if len(via) >= 50 { - return fmt.Errorf("stopped after 50 redirects") - } - return nil - }, - } + httpClient = createDefaultHTTPClient() } t := &TraefikOidc{ @@ -303,12 +320,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { if metadata != nil { t.logger.Debug("Successfully initialized provider metadata") - t.jwksURL = metadata.JWKSURL - t.authURL = metadata.AuthURL - t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer - t.revocationURL = metadata.RevokeURL - t.endSessionURL = metadata.EndSessionURL + t.updateMetadataEndpoints(metadata) // Start metadata refresh goroutine go t.startMetadataRefresh(providerURL) @@ -321,6 +333,16 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { t.logger.Error("Received nil metadata") } +// updateMetadataEndpoints updates the middleware with metadata endpoints +func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL +} + // startMetadataRefresh periodically refreshes the OIDC metadata func (t *TraefikOidc) startMetadataRefresh(providerURL string) { ticker := time.NewTicker(1 * time.Hour) @@ -335,12 +357,7 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { } if metadata != nil { - t.jwksURL = metadata.JWKSURL - t.authURL = metadata.AuthURL - t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer - t.revocationURL = metadata.RevokeURL - t.endSessionURL = metadata.EndSessionURL + t.updateMetadataEndpoints(metadata) t.logger.Debug("Successfully refreshed metadata") } } @@ -692,19 +709,24 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { params.Set("scope", strings.Join(t.scopes, " ")) } - // Ensure authURL is absolute - if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") { + return t.buildURLWithParams(t.authURL, params) +} + +// buildURLWithParams ensures a URL is absolute and appends query parameters +func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string { + // Ensure URL is absolute + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { // Extract issuer base URL issuerURL, err := url.Parse(t.issuerURL) if err == nil { return fmt.Sprintf("%s://%s%s?%s", issuerURL.Scheme, issuerURL.Host, - t.authURL, + baseURL, params.Encode()) } } - return t.authURL + "?" + params.Encode() + return baseURL + "?" + params.Encode() } // startTokenCleanup starts the token cleanup goroutine diff --git a/main_test.go b/main_test.go index 943e5a2..8f847a5 100644 --- a/main_test.go +++ b/main_test.go @@ -131,7 +131,7 @@ type MockJWKCache struct { Err error } -func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { +func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { return m.JWKS, m.Err } @@ -227,7 +227,7 @@ func TestVerifyToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Reset token blacklist and cache + // Reset token blacklist and cache for each test ts.tOidc.tokenBlacklist = NewTokenBlacklist() ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) @@ -243,9 +243,20 @@ func TestVerifyToken(t *testing.T) { } if tc.cacheToken { + // Use more realistic claims for cached token ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{ - "empty": "claim", - }, 60) + "iss": "https://test-issuer.com", + "sub": "test-subject", + "exp": float64(time.Now().Add(1 * time.Hour).Unix()), + "jti": generateRandomString(16), // Add a JTI claim to prevent replay detection + }, time.Minute) + + // Verify the token is actually in the cache + if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists { + t.Logf("Token found in cache with claims: %v", claims) + } else { + t.Logf("Token NOT found in cache despite cacheToken=true") + } } err := ts.tOidc.VerifyToken(tc.token) diff --git a/metadata_cache.go b/metadata_cache.go index d99dc3b..64f4b6c 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -35,8 +35,9 @@ func (c *MetadataCache) Cleanup() { } } func (c *MetadataCache) isCacheValid() bool { - return c.metadata != nil && time.Now().Before(c.expiresAt) + return c.metadata != nil && time.Now().Before(c.expiresAt) } + // GetMetadata retrieves the metadata from cache or fetches it if expired func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) { c.mutex.RLock() diff --git a/metadata_cache_test.go b/metadata_cache_test.go index 610a1ce..f9626ac 100644 --- a/metadata_cache_test.go +++ b/metadata_cache_test.go @@ -116,4 +116,4 @@ func TestGetMetadata_FetchError(t *testing.T) { if metadata != dummy { t.Errorf("Expected cached metadata to be returned") } -} \ No newline at end of file +}