From 9ff6779caaf608f08dd8bc75d7dbf2f0c7487dbb Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 2 Oct 2024 22:04:40 +0100 Subject: [PATCH] Add support for different signing algorithms --- jwk.go | 82 ++++++++++++++++++++++++----------- jwt.go | 119 ++++++++++++++++++++------------------------------- main.go | 53 ++++++++++++++++++----- main_test.go | 5 ++- 4 files changed, 149 insertions(+), 110 deletions(-) diff --git a/jwk.go b/jwk.go index 8c0a218..a5e07df 100644 --- a/jwk.go +++ b/jwk.go @@ -1,6 +1,8 @@ package traefikoidc import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -20,6 +22,9 @@ type JWK struct { N string `json:"n"` E string `json:"e"` Alg string `json:"alg"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` } type JWKSet struct { @@ -77,13 +82,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { return &jwks, nil } -func verifyNonce(tokenNonce, expectedNonce string) error { - if tokenNonce != expectedNonce { - return fmt.Errorf("invalid nonce") - } - return nil -} - func verifyAudience(tokenAudience, expectedAudience string) error { if tokenAudience != expectedAudience { return fmt.Errorf("invalid audience") @@ -91,17 +89,6 @@ func verifyAudience(tokenAudience, expectedAudience string) error { return nil } -func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error { - now := time.Now().Unix() - if now < issuedAt-int64(allowedClockSkew.Seconds()) { - return fmt.Errorf("token used before issued") - } - if now > expiration+int64(allowedClockSkew.Seconds()) { - return fmt.Errorf("token is expired") - } - return nil -} - func verifyIssuer(tokenIssuer, expectedIssuer string) error { if tokenIssuer != expectedIssuer { return fmt.Errorf("invalid issuer") @@ -109,17 +96,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error { return nil } -func validateClaims(claims map[string]interface{}) error { - requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} - for _, claim := range requiredClaims { - if _, ok := claims[claim]; !ok { - return fmt.Errorf("missing required claim: %s", claim) - } +func jwkToPEM(jwk *JWK) ([]byte, error) { + switch jwk.Kty { + case "RSA": + return rsaJWKToPEM(jwk) + case "EC": + return ecJWKToPEM(jwk) + default: + return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty) } - return nil } -func jwkToPEM(jwk *JWK) ([]byte, error) { +func rsaJWKToPEM(jwk *JWK) ([]byte, error) { n, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) @@ -146,3 +134,45 @@ func jwkToPEM(jwk *JWK) ([]byte, error) { return publicKeyPEM, nil } + +func ecJWKToPEM(jwk *JWK) ([]byte, error) { + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err) + } + + y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err) + } + + var curve elliptic.Curve + switch jwk.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv) + } + + publicKey := &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + } + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + return publicKeyPEM, nil +} diff --git a/jwt.go b/jwt.go index 60f15be..c97d71d 100644 --- a/jwt.go +++ b/jwt.go @@ -2,18 +2,16 @@ package traefikoidc import ( "crypto" + "crypto/ecdsa" "crypto/rsa" - "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" + "math/big" "strings" - "sync" "time" - - "golang.org/x/sync/errgroup" ) type JWT struct { @@ -75,69 +73,63 @@ func verifyExpiration(expiration float64) error { return nil } -func (t *TraefikOidc) verifySignatureConcurrently(token string, publicKeys map[string][]byte) error { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return fmt.Errorf("invalid token format") - } - - signedContent := parts[0] + "." + parts[1] - signature, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return fmt.Errorf("failed to decode signature: %w", err) - } - - var eg errgroup.Group - var mu sync.Mutex - var verificationSuccess bool - - for _, publicKeyPEM := range publicKeys { - publicKeyPEM := publicKeyPEM // Create a new variable for the goroutine - eg.Go(func() error { - err := verifySignature(signedContent, signature, publicKeyPEM) - if err == nil { - mu.Lock() - verificationSuccess = true - mu.Unlock() - } - return nil // Always return nil to continue checking other keys - }) - } - - if err := eg.Wait(); err != nil { - return err - } - - if !verificationSuccess { - return fmt.Errorf("signature verification failed for all keys") - } - - return nil -} - -func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte) error { +func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error { block, _ := pem.Decode(publicKeyPEM) if block == nil { return fmt.Errorf("failed to parse PEM block containing the public key") } - pub, err := x509.ParsePKIXPublicKey(block.Bytes) + pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return fmt.Errorf("failed to parse public key: %w", err) } - rsaPublicKey, ok := pub.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("not an RSA public key") + var hashFunc crypto.Hash + + switch alg { + case "RS256", "PS256", "ES256": + hashFunc = crypto.SHA256 + case "RS384", "PS384", "ES384": + hashFunc = crypto.SHA384 + case "RS512", "PS512", "ES512": + hashFunc = crypto.SHA512 + default: + return fmt.Errorf("unsupported algorithm: %s", alg) } - hash := sha256.Sum256([]byte(signedContent)) - err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) - if err != nil { - return fmt.Errorf("invalid token signature: %w", err) - } + h := hashFunc.New() + h.Write([]byte(signedContent)) + hashed := h.Sum(nil) - return nil + switch pub := pubKey.(type) { + case *ecdsa.PublicKey: + if strings.HasPrefix(alg, "ES") { + // ECDSA signature handling + keyBytes := (pub.Params().BitSize + 7) / 8 + if len(signature) != 2*keyBytes { + return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature)) + } + r := new(big.Int).SetBytes(signature[:keyBytes]) + s := new(big.Int).SetBytes(signature[keyBytes:]) + + if ecdsa.Verify(pub, hashed, r, s) { + return nil + } + return fmt.Errorf("invalid ECDSA signature") + } + return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg) + case *rsa.PublicKey: + if strings.HasPrefix(alg, "RS") { + err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature) + if err != nil { + return fmt.Errorf("RSA signature verification failed: %w", err) + } + return nil + } + return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg) + default: + return fmt.Errorf("unsupported public key type: %T", pub) + } } func verifyIssuedAt(issuedAt float64) error { @@ -162,20 +154,3 @@ func decodeSegment(seg string) (map[string]interface{}, error) { return result, nil } - -func (t *TraefikOidc) verifyAndCacheToken(token string) error { - return t.tokenVerifier.VerifyToken(token) -} - -func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error { - return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token) -} - -func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) { - for _, key := range jwks.Keys { - if key.Kid == kid { - return jwkToPEM(&key) - } - } - return nil, fmt.Errorf("unable to find matching public key") -} diff --git a/main.go b/main.go index 4f1f5b4..b68da30 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package traefikoidc import ( "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -98,6 +99,8 @@ func (t *TraefikOidc) VerifyToken(token string) error { } func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { + t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header) + jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) if err != nil { return fmt.Errorf("failed to get JWKS: %w", err) @@ -107,27 +110,57 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error if !ok { return fmt.Errorf("missing key ID in token header") } + t.logger.Debugf("Token kid: %s", kid) - publicKeys := make(map[string][]byte) + alg, ok := jwt.Header["alg"].(string) + if !ok { + return fmt.Errorf("missing algorithm in token header") + } + t.logger.Debugf("Token alg: %s", alg) + + var matchingKey *JWK for _, key := range jwks.Keys { if key.Kid == kid { - publicKeyPEM, err := jwkToPEM(&key) - if err != nil { - return err - } - publicKeys[key.Kid] = publicKeyPEM + matchingKey = &key + break } } - if len(publicKeys) == 0 { - return fmt.Errorf("no matching public keys found") + if matchingKey == nil { + return fmt.Errorf("no matching public key found for kid: %s", kid) + } + t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg) + + publicKeyPEM, err := jwkToPEM(matchingKey) + if err != nil { + return fmt.Errorf("failed to convert JWK to PEM: %w", err) + } + t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM)) + + parts := strings.Split(token, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid token format") } - if err := t.verifySignatureConcurrently(token, publicKeys); err != nil { + signedContent := parts[0] + "." + parts[1] + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + + if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil { + t.logger.Errorf("Signature verification failed: %v", err) return fmt.Errorf("signature verification failed: %w", err) } + t.logger.Debug("Signature verified successfully") - return jwt.Verify(t.issuerURL, t.clientID) + // Verify standard claims + if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { + return fmt.Errorf("standard claim verification failed: %w", err) + } + t.logger.Debug("Standard claims verified successfully") + + return nil } func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { diff --git a/main_test.go b/main_test.go index 1759821..f09e45d 100644 --- a/main_test.go +++ b/main_test.go @@ -247,8 +247,9 @@ func (suite *TraefikOidcTestSuite) TestBuildAuthURL() { func (suite *TraefikOidcTestSuite) TestJWKToPEM() { jwk := &JWK{ - N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()), - E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()), + Kty: "RSA", // Set the key type to RSA + N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()), } pem, err := jwkToPEM(jwk) suite.Require().NoError(err)