mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add support for different signing algorithms
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
@@ -20,6 +22,9 @@ type JWK struct {
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Alg string `json:"alg"`
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
type JWKSet struct {
|
||||
@@ -77,13 +82,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
func verifyNonce(tokenNonce, expectedNonce string) error {
|
||||
if tokenNonce != expectedNonce {
|
||||
return fmt.Errorf("invalid nonce")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyAudience(tokenAudience, expectedAudience string) error {
|
||||
if tokenAudience != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
@@ -91,17 +89,6 @@ func verifyAudience(tokenAudience, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
|
||||
now := time.Now().Unix()
|
||||
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
if now > expiration+int64(allowedClockSkew.Seconds()) {
|
||||
return fmt.Errorf("token is expired")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
@@ -109,17 +96,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateClaims(claims map[string]interface{}) error {
|
||||
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
|
||||
for _, claim := range requiredClaims {
|
||||
if _, ok := claims[claim]; !ok {
|
||||
return fmt.Errorf("missing required claim: %s", claim)
|
||||
}
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
switch jwk.Kty {
|
||||
case "RSA":
|
||||
return rsaJWKToPEM(jwk)
|
||||
case "EC":
|
||||
return ecJWKToPEM(jwk)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
||||
@@ -146,3 +134,45 @@ func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
|
||||
return publicKeyPEM, nil
|
||||
}
|
||||
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
|
||||
}
|
||||
|
||||
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
|
||||
}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch jwk.Crv {
|
||||
case "P-256":
|
||||
curve = elliptic.P256()
|
||||
case "P-384":
|
||||
curve = elliptic.P384()
|
||||
case "P-521":
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
|
||||
}
|
||||
|
||||
publicKey := &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: new(big.Int).SetBytes(x),
|
||||
Y: new(big.Int).SetBytes(y),
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
return publicKeyPEM, nil
|
||||
}
|
||||
|
||||
@@ -2,18 +2,16 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type JWT struct {
|
||||
@@ -75,69 +73,63 @@ func verifyExpiration(expiration float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifySignatureConcurrently(token string, publicKeys map[string][]byte) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
var eg errgroup.Group
|
||||
var mu sync.Mutex
|
||||
var verificationSuccess bool
|
||||
|
||||
for _, publicKeyPEM := range publicKeys {
|
||||
publicKeyPEM := publicKeyPEM // Create a new variable for the goroutine
|
||||
eg.Go(func() error {
|
||||
err := verifySignature(signedContent, signature, publicKeyPEM)
|
||||
if err == nil {
|
||||
mu.Lock()
|
||||
verificationSuccess = true
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // Always return nil to continue checking other keys
|
||||
})
|
||||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !verificationSuccess {
|
||||
return fmt.Errorf("signature verification failed for all keys")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte) error {
|
||||
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
rsaPublicKey, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("not an RSA public key")
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
case "RS384", "PS384", "ES384":
|
||||
hashFunc = crypto.SHA384
|
||||
case "RS512", "PS512", "ES512":
|
||||
hashFunc = crypto.SHA512
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(signedContent))
|
||||
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid token signature: %w", err)
|
||||
}
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
return nil
|
||||
switch pub := pubKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature handling
|
||||
keyBytes := (pub.Params().BitSize + 7) / 8
|
||||
if len(signature) != 2*keyBytes {
|
||||
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
|
||||
}
|
||||
r := new(big.Int).SetBytes(signature[:keyBytes])
|
||||
s := new(big.Int).SetBytes(signature[keyBytes:])
|
||||
|
||||
if ecdsa.Verify(pub, hashed, r, s) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("RSA signature verification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %T", pub)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
@@ -162,20 +154,3 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifyAndCacheToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token)
|
||||
}
|
||||
|
||||
func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) {
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
return jwkToPEM(&key)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unable to find matching public key")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -98,6 +99,8 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
@@ -107,27 +110,57 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
t.logger.Debugf("Token kid: %s", kid)
|
||||
|
||||
publicKeys := make(map[string][]byte)
|
||||
alg, ok := jwt.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing algorithm in token header")
|
||||
}
|
||||
t.logger.Debugf("Token alg: %s", alg)
|
||||
|
||||
var matchingKey *JWK
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
publicKeyPEM, err := jwkToPEM(&key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeys[key.Kid] = publicKeyPEM
|
||||
matchingKey = &key
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(publicKeys) == 0 {
|
||||
return fmt.Errorf("no matching public keys found")
|
||||
if matchingKey == nil {
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
|
||||
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
}
|
||||
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
if err := t.verifySignatureConcurrently(token, publicKeys); err != nil {
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
|
||||
t.logger.Errorf("Signature verification failed: %v", err)
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Signature verified successfully")
|
||||
|
||||
return jwt.Verify(t.issuerURL, t.clientID)
|
||||
// Verify standard claims
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Standard claims verified successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
|
||||
+3
-2
@@ -247,8 +247,9 @@ func (suite *TraefikOidcTestSuite) TestBuildAuthURL() {
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
|
||||
jwk := &JWK{
|
||||
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
||||
Kty: "RSA", // Set the key type to RSA
|
||||
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
||||
}
|
||||
pem, err := jwkToPEM(jwk)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
Reference in New Issue
Block a user