From 4baf3fbefd06661745707066b92f96a5b4bdc168 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 24 Jul 2024 23:53:41 +0100 Subject: [PATCH] Optimise the JWT token cache / creation and verification. --- jwt.go | 218 +++++++++++++++++++++++++++++++++++--------------------- main.go | 51 +------------ 2 files changed, 138 insertions(+), 131 deletions(-) diff --git a/jwt.go b/jwt.go index 2b1d4eb..9279140 100644 --- a/jwt.go +++ b/jwt.go @@ -14,119 +14,175 @@ import ( ) type JWT struct { - Header map[string]interface{} - Claims map[string]interface{} - Signature string + Header map[string]interface{} + Claims map[string]interface{} + Signature string } func parseJWT(token string) (*JWT, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("invalid token format") - } + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } - header, err := decodeSegment(parts[0]) - if err != nil { - return nil, fmt.Errorf("failed to decode header: %w", err) - } + header, err := decodeSegment(parts[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode header: %w", err) + } - claims, err := decodeSegment(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to decode claims: %w", err) - } + claims, err := decodeSegment(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode claims: %w", err) + } - return &JWT{ - Header: header, - Claims: claims, - Signature: parts[2], - }, nil + return &JWT{ + Header: header, + Claims: claims, + Signature: parts[2], + }, nil } func (j *JWT) Verify(issuerURL, clientID string) error { - claims := j.Claims + claims := j.Claims - if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil { - return err - } + if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil { + return err + } - if err := verifyAudience(claims["aud"].(string), clientID); err != nil { - return err - } + if err := verifyAudience(claims["aud"].(string), clientID); err != nil { + return err + } - if err := verifyExpiration(claims["exp"].(float64)); err != nil { - return err - } + if err := verifyExpiration(claims["exp"].(float64)); err != nil { + return err + } - if err := verifyIssuedAt(claims["iat"].(float64)); err != nil { - return err - } + if err := verifyIssuedAt(claims["iat"].(float64)); err != nil { + return err + } - return nil + return nil } func verifyExpiration(expiration float64) error { - expirationTime := time.Unix(int64(expiration), 0) - if time.Now().After(expirationTime) { - return fmt.Errorf("token has expired") - } - return nil + expirationTime := time.Unix(int64(expiration), 0) + if time.Now().After(expirationTime) { + return fmt.Errorf("token has expired") + } + return nil } func verifySignature(token string, publicKeyPEM []byte) error { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return fmt.Errorf("invalid token format") - } + parts := strings.Split(token, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid token format") + } - block, _ := pem.Decode(publicKeyPEM) - if block == nil { - return fmt.Errorf("failed to parse PEM block containing the public key") - } + block, _ := pem.Decode(publicKeyPEM) + if block == nil { + return fmt.Errorf("failed to parse PEM block containing the public key") + } - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return fmt.Errorf("failed to parse public key: %w", err) - } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse public key: %w", err) + } - rsaPublicKey, ok := pub.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("not an RSA public key") - } + rsaPublicKey, ok := pub.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("not an RSA public key") + } - signedContent := parts[0] + "." + parts[1] - signature, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return fmt.Errorf("failed to decode signature: %w", err) - } + signedContent := parts[0] + "." + parts[1] + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } - hash := sha256.Sum256([]byte(signedContent)) - err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) - if err != nil { - return fmt.Errorf("invalid token signature: %w", err) - } + hash := sha256.Sum256([]byte(signedContent)) + err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) + if err != nil { + return fmt.Errorf("invalid token signature: %w", err) + } - return nil + return nil } func verifyIssuedAt(issuedAt float64) error { - issuedAtTime := time.Unix(int64(issuedAt), 0) - if time.Now().Before(issuedAtTime) { - return fmt.Errorf("token used before issued") - } - return nil + issuedAtTime := time.Unix(int64(issuedAt), 0) + if time.Now().Before(issuedAtTime) { + return fmt.Errorf("token used before issued") + } + return nil } func decodeSegment(seg string) (map[string]interface{}, error) { - data, err := base64.RawURLEncoding.DecodeString(seg) - if err != nil { - return nil, fmt.Errorf("failed to decode segment: %w", err) - } + data, err := base64.RawURLEncoding.DecodeString(seg) + if err != nil { + return nil, fmt.Errorf("failed to decode segment: %w", err) + } - var result map[string]interface{} - err = json.Unmarshal(data, &result) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal segment: %w", err) - } + var result map[string]interface{} + err = json.Unmarshal(data, &result) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal segment: %w", err) + } - return result, nil + return result, nil +} + +func (t *TraefikOidc) verifyAndCacheToken(token string) error { + if !t.limiter.Allow() { + return fmt.Errorf("rate limit exceeded") + } + + if _, exists := t.tokenCache.Get(token); exists { + return nil // Token is valid and cached + } + + jwt, err := parseJWT(token) + if err != nil { + return fmt.Errorf("failed to parse JWT: %w", err) + } + + if err := t.verifyJWTSignatureAndClaims(jwt, token); err != nil { + return err + } + + expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) + t.tokenCache.Set(token, expirationTime) + + return nil +} + +func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error { + jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) + if err != nil { + return fmt.Errorf("failed to get JWKS: %w", err) + } + + kid, ok := jwt.Header["kid"].(string) + if !ok { + return fmt.Errorf("missing key ID in token header") + } + + publicKeyPEM, err := getPublicKeyPEM(jwks, kid) + if err != nil { + return err + } + + if err := verifySignature(token, publicKeyPEM); err != nil { + return fmt.Errorf("signature verification failed: %w", err) + } + + return jwt.Verify(t.issuerURL, t.clientID) +} + +func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) { + for _, key := range jwks.Keys { + if key.Kid == kid { + return jwkToPEM(&key) + } + } + return nil, fmt.Errorf("unable to find matching public key") } diff --git a/main.go b/main.go index 27d571e..e8a684a 100644 --- a/main.go +++ b/main.go @@ -202,56 +202,7 @@ func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.R } func (t *TraefikOidc) verifyToken(token string) error { - if !t.limiter.Allow() { - return fmt.Errorf("rate limit exceeded") - } - - if _, exists := t.tokenCache.Get(token); exists { - return nil // Token is valid and cached - } - - jwt, err := parseJWT(token) - if err != nil { - return fmt.Errorf("failed to parse JWT: %w", err) - } - - jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) - if err != nil { - return fmt.Errorf("failed to get JWKS: %w", err) - } - - kid, ok := jwt.Header["kid"].(string) - if !ok { - return fmt.Errorf("missing key ID in token header") - } - - var publicKeyPEM []byte - for _, key := range jwks.Keys { - if key.Kid == kid { - publicKeyPEM, err = jwkToPEM(&key) - if err != nil { - return fmt.Errorf("failed to convert JWK to PEM: %w", err) - } - break - } - } - - if publicKeyPEM == nil { - return fmt.Errorf("unable to find matching public key") - } - - if err := verifySignature(token, publicKeyPEM); err != nil { - return fmt.Errorf("signature verification failed: %w", err) - } - - if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { - return fmt.Errorf("JWT verification failed: %w", err) - } - - expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) - t.tokenCache.Set(token, expirationTime) - - return nil + return t.verifyAndCacheToken(token) } func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {