Add support for different signing algorithms

This commit is contained in:
2024-10-02 22:04:40 +01:00
parent a7d42de0a4
commit 9ff6779caa
4 changed files with 149 additions and 110 deletions
+56 -26
View File
@@ -1,6 +1,8 @@
package traefikoidc package traefikoidc
import ( import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -20,6 +22,9 @@ type JWK struct {
N string `json:"n"` N string `json:"n"`
E string `json:"e"` E string `json:"e"`
Alg string `json:"alg"` Alg string `json:"alg"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
} }
type JWKSet struct { type JWKSet struct {
@@ -77,13 +82,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil return &jwks, nil
} }
func verifyNonce(tokenNonce, expectedNonce string) error {
if tokenNonce != expectedNonce {
return fmt.Errorf("invalid nonce")
}
return nil
}
func verifyAudience(tokenAudience, expectedAudience string) error { func verifyAudience(tokenAudience, expectedAudience string) error {
if tokenAudience != expectedAudience { if tokenAudience != expectedAudience {
return fmt.Errorf("invalid audience") return fmt.Errorf("invalid audience")
@@ -91,17 +89,6 @@ func verifyAudience(tokenAudience, expectedAudience string) error {
return nil return nil
} }
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
now := time.Now().Unix()
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token used before issued")
}
if now > expiration+int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token is expired")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error { func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer { if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer") return fmt.Errorf("invalid issuer")
@@ -109,17 +96,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
return nil return nil
} }
func validateClaims(claims map[string]interface{}) error { func jwkToPEM(jwk *JWK) ([]byte, error) {
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} switch jwk.Kty {
for _, claim := range requiredClaims { case "RSA":
if _, ok := claims[claim]; !ok { return rsaJWKToPEM(jwk)
return fmt.Errorf("missing required claim: %s", claim) case "EC":
} return ecJWKToPEM(jwk)
default:
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
} }
return nil
} }
func jwkToPEM(jwk *JWK) ([]byte, error) { func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N) n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
@@ -146,3 +134,45 @@ func jwkToPEM(jwk *JWK) ([]byte, error) {
return publicKeyPEM, nil return publicKeyPEM, nil
} }
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
publicKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
return publicKeyPEM, nil
}
+47 -72
View File
@@ -2,18 +2,16 @@ package traefikoidc
import ( import (
"crypto" "crypto"
"crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/sync/errgroup"
) )
type JWT struct { type JWT struct {
@@ -75,69 +73,63 @@ func verifyExpiration(expiration float64) error {
return nil return nil
} }
func (t *TraefikOidc) verifySignatureConcurrently(token string, publicKeys map[string][]byte) error { func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
var eg errgroup.Group
var mu sync.Mutex
var verificationSuccess bool
for _, publicKeyPEM := range publicKeys {
publicKeyPEM := publicKeyPEM // Create a new variable for the goroutine
eg.Go(func() error {
err := verifySignature(signedContent, signature, publicKeyPEM)
if err == nil {
mu.Lock()
verificationSuccess = true
mu.Unlock()
}
return nil // Always return nil to continue checking other keys
})
}
if err := eg.Wait(); err != nil {
return err
}
if !verificationSuccess {
return fmt.Errorf("signature verification failed for all keys")
}
return nil
}
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte) error {
block, _ := pem.Decode(publicKeyPEM) block, _ := pem.Decode(publicKeyPEM)
if block == nil { if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key") return fmt.Errorf("failed to parse PEM block containing the public key")
} }
pub, err := x509.ParsePKIXPublicKey(block.Bytes) pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse public key: %w", err) return fmt.Errorf("failed to parse public key: %w", err)
} }
rsaPublicKey, ok := pub.(*rsa.PublicKey) var hashFunc crypto.Hash
if !ok {
return fmt.Errorf("not an RSA public key") switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
} }
hash := sha256.Sum256([]byte(signedContent)) h := hashFunc.New()
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) h.Write([]byte(signedContent))
if err != nil { hashed := h.Sum(nil)
return fmt.Errorf("invalid token signature: %w", err)
}
return nil switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pub, hashed, r, s) {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pub)
}
} }
func verifyIssuedAt(issuedAt float64) error { func verifyIssuedAt(issuedAt float64) error {
@@ -162,20 +154,3 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
return result, nil return result, nil
} }
func (t *TraefikOidc) verifyAndCacheToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token)
}
func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) {
for _, key := range jwks.Keys {
if key.Kid == kid {
return jwkToPEM(&key)
}
}
return nil, fmt.Errorf("unable to find matching public key")
}
+43 -10
View File
@@ -2,6 +2,7 @@ package traefikoidc
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -98,6 +99,8 @@ func (t *TraefikOidc) VerifyToken(token string) error {
} }
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil { if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err) return fmt.Errorf("failed to get JWKS: %w", err)
@@ -107,27 +110,57 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
if !ok { if !ok {
return fmt.Errorf("missing key ID in token header") return fmt.Errorf("missing key ID in token header")
} }
t.logger.Debugf("Token kid: %s", kid)
publicKeys := make(map[string][]byte) alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
var matchingKey *JWK
for _, key := range jwks.Keys { for _, key := range jwks.Keys {
if key.Kid == kid { if key.Kid == kid {
publicKeyPEM, err := jwkToPEM(&key) matchingKey = &key
if err != nil { break
return err
}
publicKeys[key.Kid] = publicKeyPEM
} }
} }
if len(publicKeys) == 0 { if matchingKey == nil {
return fmt.Errorf("no matching public keys found") return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
} }
if err := t.verifySignatureConcurrently(token, publicKeys); err != nil { signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
return fmt.Errorf("signature verification failed: %w", err) return fmt.Errorf("signature verification failed: %w", err)
} }
t.logger.Debug("Signature verified successfully")
return jwt.Verify(t.issuerURL, t.clientID) // Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
} }
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
+3 -2
View File
@@ -247,8 +247,9 @@ func (suite *TraefikOidcTestSuite) TestBuildAuthURL() {
func (suite *TraefikOidcTestSuite) TestJWKToPEM() { func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
jwk := &JWK{ jwk := &JWK{
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()), Kty: "RSA", // Set the key type to RSA
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()), N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
} }
pem, err := jwkToPEM(jwk) pem, err := jwkToPEM(jwk)
suite.Require().NoError(err) suite.Require().NoError(err)