mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Cleanup the codebase, DRY and abstract functions, increase the test coverage.
This commit is contained in:
+1
-1
@@ -15,4 +15,4 @@ func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup fu
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -19,4 +19,4 @@ func TestAutoCleanupRoutine(t *testing.T) {
|
||||
if atomic.LoadInt32(&counter) < 3 {
|
||||
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,92 +1,51 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It contains the cryptographic key information used for token verification.
|
||||
type JWK struct {
|
||||
// Kty is the key type (e.g., "RSA", "EC")
|
||||
Kty string `json:"kty"`
|
||||
|
||||
// Kid is the unique key identifier
|
||||
Kid string `json:"kid"`
|
||||
|
||||
// Use specifies the intended use of the key (e.g., "sig" for signature)
|
||||
Use string `json:"use"`
|
||||
|
||||
// N is the modulus for RSA keys
|
||||
N string `json:"n"`
|
||||
|
||||
// E is the exponent for RSA keys
|
||||
E string `json:"e"`
|
||||
|
||||
// Alg is the algorithm intended for use with the key
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Alg string `json:"alg"`
|
||||
|
||||
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
|
||||
Crv string `json:"crv"`
|
||||
|
||||
// X is the x-coordinate for EC keys
|
||||
X string `json:"x"`
|
||||
|
||||
// Y is the y-coordinate for EC keys
|
||||
Y string `json:"y"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
|
||||
// OIDC providers typically expose multiple keys to support key rotation.
|
||||
type JWKSet struct {
|
||||
// Keys is the array of JSON Web Keys
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides a thread-safe caching mechanism for JWK sets.
|
||||
// It caches the keys for a configurable duration to reduce load on the OIDC provider
|
||||
// while ensuring keys are refreshed periodically to handle key rotation.
|
||||
type JWKCache struct {
|
||||
// jwks holds the cached set of JSON Web Keys
|
||||
jwks *JWKSet
|
||||
|
||||
// expiresAt is the timestamp when the cached keys should be refreshed
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
|
||||
// mutex protects concurrent access to the cache
|
||||
mutex sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
// CacheLifetime is configurable to determine how long the JWKS is cached.
|
||||
CacheLifetime time.Duration
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for JWK caching operations.
|
||||
// This interface allows for different caching implementations while
|
||||
// maintaining consistent behavior in the token verification process.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
Cleanup() // Add Cleanup method to the interface
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
|
||||
// from the OIDC provider. It implements a thread-safe double-checked locking
|
||||
// pattern to prevent multiple simultaneous fetches of the same keys.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for fetching keys
|
||||
//
|
||||
// Returns:
|
||||
// - The JSON Web Key Set
|
||||
// - An error if the keys cannot be retrieved or parsed
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
defer c.mutex.RUnlock()
|
||||
@@ -96,23 +55,25 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
return c.jwks, nil
|
||||
}
|
||||
|
||||
jwks, err := fetchJWKS(jwksURL, httpClient)
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.jwks = jwks
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
lifetime := c.CacheLifetime
|
||||
if lifetime == 0 {
|
||||
lifetime = 1 * time.Hour
|
||||
}
|
||||
c.expiresAt = time.Now().Add(lifetime)
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// Cleanup removes expired JWKs from the cache.
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -123,17 +84,14 @@ func (c *JWKCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
|
||||
// It handles HTTP communication and JSON parsing of the response.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for the request
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed JSON Web Key Set
|
||||
// - An error if the request fails or the response is invalid
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||
}
|
||||
@@ -151,9 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
|
||||
// cryptographic functions. It supports both RSA and EC keys, delegating
|
||||
// to the appropriate converter based on the key type.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -169,9 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
|
||||
// It handles base64url decoding of the modulus and exponent,
|
||||
// constructs an RSA public key, and encodes it in PEM format.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -203,10 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
|
||||
// It supports the P-256, P-384, and P-521 curves as defined in the
|
||||
// OIDC specification, decoding the x and y coordinates and encoding
|
||||
// the resulting public key in PEM format.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,44 +4,41 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var replayCacheMu sync.Mutex
|
||||
var replayCache = make(map[string]time.Time)
|
||||
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
if expiry.Before(now) {
|
||||
delete(replayCache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
// It contains the three parts of a JWT: header, claims (payload),
|
||||
// and signature, along with the original token string.
|
||||
type JWT struct {
|
||||
// Header contains the token metadata (algorithm, key ID, etc.)
|
||||
Header map[string]interface{}
|
||||
|
||||
// Claims contains the token claims (subject, expiration, etc.)
|
||||
Claims map[string]interface{}
|
||||
|
||||
// Signature contains the raw signature bytes
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature []byte
|
||||
|
||||
// Token is the original JWT string
|
||||
Token string
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// It validates the token format and decodes the three parts
|
||||
// (header, claims, signature) using base64url decoding.
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT token string
|
||||
//
|
||||
// Returns:
|
||||
// - A parsed JWT struct
|
||||
// - An error if the token format is invalid or parsing fails
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -52,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
@@ -61,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the claims
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
@@ -70,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
@@ -81,28 +75,13 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// It checks:
|
||||
// - issuer (iss) matches the expected issuer URL
|
||||
// - audience (aud) includes the client ID
|
||||
// - expiration time (exp) is in the future (with clock skew tolerance)
|
||||
// - issued at time (iat) is in the past (with clock skew tolerance)
|
||||
// - not before time (nbf) is in the past (with clock skew tolerance)
|
||||
// - subject (sub) is present and not empty
|
||||
// - algorithm matches expected value to prevent algorithm switching attacks
|
||||
//
|
||||
// Returns an error if any validation fails.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Debug logging of validation parameters
|
||||
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
|
||||
// Debug logging of token header
|
||||
fmt.Printf("Token header: %+v\n", j.Header)
|
||||
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
}
|
||||
// List of supported algorithms - should match those in verifySignature
|
||||
supportedAlgs := map[string]bool{
|
||||
"RS256": true, "RS384": true, "RS512": true,
|
||||
"PS256": true, "PS384": true, "PS512": true,
|
||||
@@ -114,9 +93,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
// Debug logging of all claims
|
||||
fmt.Printf("Token claims: %+v\n", claims)
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'iss' claim")
|
||||
@@ -149,17 +125,36 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate nbf (not before) claim if present
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if err := verifyNotBefore(nbf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate jti (JWT ID) claim if present
|
||||
// Implement replay protection by checking the jti (JWT ID)
|
||||
if jti, ok := claims["jti"].(string); ok {
|
||||
// Could add replay detection here if needed
|
||||
_ = jti
|
||||
// Skip replay detection for tokens that are being verified from the cache
|
||||
if j.Token == "" {
|
||||
// This is a parsed JWT without the original token string,
|
||||
// which means it's likely from a cached token verification
|
||||
return nil
|
||||
}
|
||||
|
||||
replayCacheMu.Lock()
|
||||
cleanupReplayCache()
|
||||
if _, exists := replayCache[jti]; exists {
|
||||
replayCacheMu.Unlock()
|
||||
return fmt.Errorf("token replay detected")
|
||||
}
|
||||
expFloat, ok := claims["exp"].(float64)
|
||||
var expTime time.Time
|
||||
if ok {
|
||||
expTime = time.Unix(int64(expFloat), 0)
|
||||
} else {
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
replayCache[jti] = expTime
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
@@ -169,20 +164,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience validates the token's audience claim.
|
||||
// The audience can be either a single string or an array of strings.
|
||||
// For array audiences, the expected audience must match any one value.
|
||||
// Parameters:
|
||||
// - tokenAudience: The audience claim from the token
|
||||
// - expectedAudience: The expected audience value
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
|
||||
tokenAudience, expectedAudience)
|
||||
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
@@ -205,165 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer validates the token's issuer claim.
|
||||
// The issuer URL must exactly match the expected issuer.
|
||||
// Parameters:
|
||||
// - tokenIssuer: The issuer claim from the token
|
||||
// - expectedIssuer: The expected issuer URL
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
|
||||
tokenIssuer, expectedIssuer)
|
||||
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
|
||||
tokenIssuer, expectedIssuer)
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clock skew tolerance for time-based validations
|
||||
const clockSkewTolerance = 2 * time.Minute
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token's expiration time has passed.
|
||||
// The expiration time is compared against the current time with clock skew tolerance.
|
||||
// Parameters:
|
||||
// - expiration: The expiration timestamp from the token
|
||||
//
|
||||
// Returns an error if the token has expired.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
expirationTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that expire exactly now
|
||||
if expirationTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.After(expirationTime) {
|
||||
return fmt.Errorf("token has expired (exp: %v, now: %v)",
|
||||
expirationTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt validates the token's issued-at time.
|
||||
// Ensures the token wasn't issued in the future, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - issuedAt: The issued-at timestamp from the token
|
||||
//
|
||||
// Returns an error if the token was issued in the future.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
issuedAtTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens issued in the same second as current time
|
||||
if issuedAtTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
||||
issuedAtTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifyNotBefore validates the token's not-before time if present.
|
||||
// Ensures the token is not used before its valid time period, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - notBefore: The not-before timestamp from the token
|
||||
//
|
||||
// Returns an error if the token is not yet valid.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
notBeforeTime := time.Unix(int64(notBefore), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
notBeforeTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that become valid exactly now
|
||||
if notBeforeTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(notBeforeTime) {
|
||||
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
|
||||
notBeforeTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the token's cryptographic signature.
|
||||
// Supports multiple signature algorithms:
|
||||
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
|
||||
// - RSA-PSS: PS256, PS384, PS512
|
||||
// - ECDSA: ES256, ES384, ES512
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The complete JWT token string
|
||||
// - publicKeyPEM: The PEM-encoded public key for verification
|
||||
// - alg: The signature algorithm identifier
|
||||
//
|
||||
// Returns an error if signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
|
||||
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
@@ -374,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
|
||||
@@ -18,6 +18,40 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
|
||||
func createDefaultHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
@@ -82,11 +116,40 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
// Check cache first
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
t.logger.Debugf("Token found in cache with valid claims; skipping verification")
|
||||
return nil
|
||||
}
|
||||
|
||||
t.logger.Debugf("Verifying token")
|
||||
|
||||
// Rate limiting
|
||||
// Perform pre-verification checks
|
||||
if err := t.performPreVerificationChecks(token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and standard claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the verified token
|
||||
t.cacheVerifiedToken(token, jwt.Claims)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// performPreVerificationChecks performs rate limiting and blacklist checks
|
||||
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
|
||||
// Enforce rate limiting
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
@@ -96,30 +159,15 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
// Check if token is cached
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the token until it expires
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
// cacheVerifiedToken caches a verified token until its expiration time
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
||||
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
|
||||
now := time.Now()
|
||||
duration := expirationTime.Sub(now)
|
||||
t.tokenCache.Set(token, jwt.Claims, duration)
|
||||
|
||||
return nil
|
||||
t.tokenCache.Set(token, claims, duration)
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
@@ -127,7 +175,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
|
||||
// Get JWKS
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
@@ -187,7 +235,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
@@ -198,42 +245,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup HTTP client
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
httpClient = createDefaultHTTPClient()
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
@@ -303,12 +320,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
|
||||
// Start metadata refresh goroutine
|
||||
go t.startMetadataRefresh(providerURL)
|
||||
@@ -321,6 +333,16 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Error("Received nil metadata")
|
||||
}
|
||||
|
||||
// updateMetadataEndpoints updates the middleware with metadata endpoints
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
}
|
||||
|
||||
// startMetadataRefresh periodically refreshes the OIDC metadata
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
@@ -335,12 +357,7 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
}
|
||||
}
|
||||
@@ -692,19 +709,24 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
// Ensure authURL is absolute
|
||||
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams ensures a URL is absolute and appends query parameters
|
||||
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
// Ensure URL is absolute
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
// Extract issuer base URL
|
||||
issuerURL, err := url.Parse(t.issuerURL)
|
||||
if err == nil {
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
t.authURL,
|
||||
baseURL,
|
||||
params.Encode())
|
||||
}
|
||||
}
|
||||
return t.authURL + "?" + params.Encode()
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
|
||||
+15
-4
@@ -131,7 +131,7 @@ type MockJWKCache struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
@@ -227,7 +227,7 @@ func TestVerifyToken(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset token blacklist and cache
|
||||
// Reset token blacklist and cache for each test
|
||||
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
|
||||
ts.tOidc.tokenCache = NewTokenCache()
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
||||
@@ -243,9 +243,20 @@ func TestVerifyToken(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.cacheToken {
|
||||
// Use more realistic claims for cached token
|
||||
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
|
||||
"empty": "claim",
|
||||
}, 60)
|
||||
"iss": "https://test-issuer.com",
|
||||
"sub": "test-subject",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
|
||||
}, time.Minute)
|
||||
|
||||
// Verify the token is actually in the cache
|
||||
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
|
||||
t.Logf("Token found in cache with claims: %v", claims)
|
||||
} else {
|
||||
t.Logf("Token NOT found in cache despite cacheToken=true")
|
||||
}
|
||||
}
|
||||
|
||||
err := ts.tOidc.VerifyToken(tc.token)
|
||||
|
||||
+2
-1
@@ -35,8 +35,9 @@ func (c *MetadataCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
|
||||
@@ -116,4 +116,4 @@ func TestGetMetadata_FetchError(t *testing.T) {
|
||||
if metadata != dummy {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user