Add support for different signing algorithms

This commit is contained in:
2024-10-02 22:04:40 +01:00
parent a7d42de0a4
commit 9ff6779caa
4 changed files with 149 additions and 110 deletions
+56 -26
View File
@@ -1,6 +1,8 @@
package traefikoidc
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
@@ -20,6 +22,9 @@ type JWK struct {
N string `json:"n"`
E string `json:"e"`
Alg string `json:"alg"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
}
type JWKSet struct {
@@ -77,13 +82,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
func verifyNonce(tokenNonce, expectedNonce string) error {
if tokenNonce != expectedNonce {
return fmt.Errorf("invalid nonce")
}
return nil
}
func verifyAudience(tokenAudience, expectedAudience string) error {
if tokenAudience != expectedAudience {
return fmt.Errorf("invalid audience")
@@ -91,17 +89,6 @@ func verifyAudience(tokenAudience, expectedAudience string) error {
return nil
}
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
now := time.Now().Unix()
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token used before issued")
}
if now > expiration+int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token is expired")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
@@ -109,17 +96,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
return nil
}
func validateClaims(claims map[string]interface{}) error {
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
for _, claim := range requiredClaims {
if _, ok := claims[claim]; !ok {
return fmt.Errorf("missing required claim: %s", claim)
}
func jwkToPEM(jwk *JWK) ([]byte, error) {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return nil
}
func jwkToPEM(jwk *JWK) ([]byte, error) {
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
@@ -146,3 +134,45 @@ func jwkToPEM(jwk *JWK) ([]byte, error) {
return publicKeyPEM, nil
}
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
publicKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
return publicKeyPEM, nil
}
+47 -72
View File
@@ -2,18 +2,16 @@ package traefikoidc
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"sync"
"time"
"golang.org/x/sync/errgroup"
)
type JWT struct {
@@ -75,69 +73,63 @@ func verifyExpiration(expiration float64) error {
return nil
}
func (t *TraefikOidc) verifySignatureConcurrently(token string, publicKeys map[string][]byte) error {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
var eg errgroup.Group
var mu sync.Mutex
var verificationSuccess bool
for _, publicKeyPEM := range publicKeys {
publicKeyPEM := publicKeyPEM // Create a new variable for the goroutine
eg.Go(func() error {
err := verifySignature(signedContent, signature, publicKeyPEM)
if err == nil {
mu.Lock()
verificationSuccess = true
mu.Unlock()
}
return nil // Always return nil to continue checking other keys
})
}
if err := eg.Wait(); err != nil {
return err
}
if !verificationSuccess {
return fmt.Errorf("signature verification failed for all keys")
}
return nil
}
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte) error {
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("not an RSA public key")
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
hash := sha256.Sum256([]byte(signedContent))
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
if err != nil {
return fmt.Errorf("invalid token signature: %w", err)
}
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
return nil
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pub, hashed, r, s) {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pub)
}
}
func verifyIssuedAt(issuedAt float64) error {
@@ -162,20 +154,3 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
return result, nil
}
func (t *TraefikOidc) verifyAndCacheToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token)
}
func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) {
for _, key := range jwks.Keys {
if key.Kid == kid {
return jwkToPEM(&key)
}
}
return nil, fmt.Errorf("unable to find matching public key")
}
+43 -10
View File
@@ -2,6 +2,7 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -98,6 +99,8 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
@@ -107,27 +110,57 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
publicKeys := make(map[string][]byte)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
publicKeyPEM, err := jwkToPEM(&key)
if err != nil {
return err
}
publicKeys[key.Kid] = publicKeyPEM
matchingKey = &key
break
}
}
if len(publicKeys) == 0 {
return fmt.Errorf("no matching public keys found")
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
if err := t.verifySignatureConcurrently(token, publicKeys); err != nil {
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
return jwt.Verify(t.issuerURL, t.clientID)
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
+3 -2
View File
@@ -247,8 +247,9 @@ func (suite *TraefikOidcTestSuite) TestBuildAuthURL() {
func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
jwk := &JWK{
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
Kty: "RSA", // Set the key type to RSA
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
}
pem, err := jwkToPEM(jwk)
suite.Require().NoError(err)