mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add support for different signing algorithms
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
package traefikoidc
|
package traefikoidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@@ -20,6 +22,9 @@ type JWK struct {
|
|||||||
N string `json:"n"`
|
N string `json:"n"`
|
||||||
E string `json:"e"`
|
E string `json:"e"`
|
||||||
Alg string `json:"alg"`
|
Alg string `json:"alg"`
|
||||||
|
Crv string `json:"crv"`
|
||||||
|
X string `json:"x"`
|
||||||
|
Y string `json:"y"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWKSet struct {
|
type JWKSet struct {
|
||||||
@@ -77,13 +82,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
|||||||
return &jwks, nil
|
return &jwks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyNonce(tokenNonce, expectedNonce string) error {
|
|
||||||
if tokenNonce != expectedNonce {
|
|
||||||
return fmt.Errorf("invalid nonce")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyAudience(tokenAudience, expectedAudience string) error {
|
func verifyAudience(tokenAudience, expectedAudience string) error {
|
||||||
if tokenAudience != expectedAudience {
|
if tokenAudience != expectedAudience {
|
||||||
return fmt.Errorf("invalid audience")
|
return fmt.Errorf("invalid audience")
|
||||||
@@ -91,17 +89,6 @@ func verifyAudience(tokenAudience, expectedAudience string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
|
|
||||||
now := time.Now().Unix()
|
|
||||||
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
|
|
||||||
return fmt.Errorf("token used before issued")
|
|
||||||
}
|
|
||||||
if now > expiration+int64(allowedClockSkew.Seconds()) {
|
|
||||||
return fmt.Errorf("token is expired")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||||
if tokenIssuer != expectedIssuer {
|
if tokenIssuer != expectedIssuer {
|
||||||
return fmt.Errorf("invalid issuer")
|
return fmt.Errorf("invalid issuer")
|
||||||
@@ -109,17 +96,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateClaims(claims map[string]interface{}) error {
|
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||||
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
|
switch jwk.Kty {
|
||||||
for _, claim := range requiredClaims {
|
case "RSA":
|
||||||
if _, ok := claims[claim]; !ok {
|
return rsaJWKToPEM(jwk)
|
||||||
return fmt.Errorf("missing required claim: %s", claim)
|
case "EC":
|
||||||
}
|
return ecJWKToPEM(jwk)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||||
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
||||||
@@ -146,3 +134,45 @@ func jwkToPEM(jwk *JWK) ([]byte, error) {
|
|||||||
|
|
||||||
return publicKeyPEM, nil
|
return publicKeyPEM, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||||
|
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var curve elliptic.Curve
|
||||||
|
switch jwk.Crv {
|
||||||
|
case "P-256":
|
||||||
|
curve = elliptic.P256()
|
||||||
|
case "P-384":
|
||||||
|
curve = elliptic.P384()
|
||||||
|
case "P-521":
|
||||||
|
curve = elliptic.P521()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := &ecdsa.PublicKey{
|
||||||
|
Curve: curve,
|
||||||
|
X: new(big.Int).SetBytes(x),
|
||||||
|
Y: new(big.Int).SetBytes(y),
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PUBLIC KEY",
|
||||||
|
Bytes: publicKeyBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
return publicKeyPEM, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,18 +2,16 @@ package traefikoidc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWT struct {
|
type JWT struct {
|
||||||
@@ -75,69 +73,63 @@ func verifyExpiration(expiration float64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) verifySignatureConcurrently(token string, publicKeys map[string][]byte) error {
|
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return fmt.Errorf("invalid token format")
|
|
||||||
}
|
|
||||||
|
|
||||||
signedContent := parts[0] + "." + parts[1]
|
|
||||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to decode signature: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var eg errgroup.Group
|
|
||||||
var mu sync.Mutex
|
|
||||||
var verificationSuccess bool
|
|
||||||
|
|
||||||
for _, publicKeyPEM := range publicKeys {
|
|
||||||
publicKeyPEM := publicKeyPEM // Create a new variable for the goroutine
|
|
||||||
eg.Go(func() error {
|
|
||||||
err := verifySignature(signedContent, signature, publicKeyPEM)
|
|
||||||
if err == nil {
|
|
||||||
mu.Lock()
|
|
||||||
verificationSuccess = true
|
|
||||||
mu.Unlock()
|
|
||||||
}
|
|
||||||
return nil // Always return nil to continue checking other keys
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := eg.Wait(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !verificationSuccess {
|
|
||||||
return fmt.Errorf("signature verification failed for all keys")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte) error {
|
|
||||||
block, _ := pem.Decode(publicKeyPEM)
|
block, _ := pem.Decode(publicKeyPEM)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse public key: %w", err)
|
return fmt.Errorf("failed to parse public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rsaPublicKey, ok := pub.(*rsa.PublicKey)
|
var hashFunc crypto.Hash
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("not an RSA public key")
|
switch alg {
|
||||||
|
case "RS256", "PS256", "ES256":
|
||||||
|
hashFunc = crypto.SHA256
|
||||||
|
case "RS384", "PS384", "ES384":
|
||||||
|
hashFunc = crypto.SHA384
|
||||||
|
case "RS512", "PS512", "ES512":
|
||||||
|
hashFunc = crypto.SHA512
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(signedContent))
|
h := hashFunc.New()
|
||||||
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
|
h.Write([]byte(signedContent))
|
||||||
if err != nil {
|
hashed := h.Sum(nil)
|
||||||
return fmt.Errorf("invalid token signature: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
switch pub := pubKey.(type) {
|
||||||
|
case *ecdsa.PublicKey:
|
||||||
|
if strings.HasPrefix(alg, "ES") {
|
||||||
|
// ECDSA signature handling
|
||||||
|
keyBytes := (pub.Params().BitSize + 7) / 8
|
||||||
|
if len(signature) != 2*keyBytes {
|
||||||
|
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
|
||||||
|
}
|
||||||
|
r := new(big.Int).SetBytes(signature[:keyBytes])
|
||||||
|
s := new(big.Int).SetBytes(signature[keyBytes:])
|
||||||
|
|
||||||
|
if ecdsa.Verify(pub, hashed, r, s) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid ECDSA signature")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
|
||||||
|
case *rsa.PublicKey:
|
||||||
|
if strings.HasPrefix(alg, "RS") {
|
||||||
|
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("RSA signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported public key type: %T", pub)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyIssuedAt(issuedAt float64) error {
|
func verifyIssuedAt(issuedAt float64) error {
|
||||||
@@ -162,20 +154,3 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
|
|||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) verifyAndCacheToken(token string) error {
|
|
||||||
return t.tokenVerifier.VerifyToken(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
|
||||||
return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) {
|
|
||||||
for _, key := range jwks.Keys {
|
|
||||||
if key.Kid == kid {
|
|
||||||
return jwkToPEM(&key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to find matching public key")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package traefikoidc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -98,6 +99,8 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||||
|
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
|
||||||
|
|
||||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||||
@@ -107,27 +110,57 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
|||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("missing key ID in token header")
|
return fmt.Errorf("missing key ID in token header")
|
||||||
}
|
}
|
||||||
|
t.logger.Debugf("Token kid: %s", kid)
|
||||||
|
|
||||||
publicKeys := make(map[string][]byte)
|
alg, ok := jwt.Header["alg"].(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("missing algorithm in token header")
|
||||||
|
}
|
||||||
|
t.logger.Debugf("Token alg: %s", alg)
|
||||||
|
|
||||||
|
var matchingKey *JWK
|
||||||
for _, key := range jwks.Keys {
|
for _, key := range jwks.Keys {
|
||||||
if key.Kid == kid {
|
if key.Kid == kid {
|
||||||
publicKeyPEM, err := jwkToPEM(&key)
|
matchingKey = &key
|
||||||
if err != nil {
|
break
|
||||||
return err
|
|
||||||
}
|
|
||||||
publicKeys[key.Kid] = publicKeyPEM
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(publicKeys) == 0 {
|
if matchingKey == nil {
|
||||||
return fmt.Errorf("no matching public keys found")
|
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||||
|
}
|
||||||
|
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
|
||||||
|
|
||||||
|
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||||
|
}
|
||||||
|
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
|
||||||
|
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return fmt.Errorf("invalid token format")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := t.verifySignatureConcurrently(token, publicKeys); err != nil {
|
signedContent := parts[0] + "." + parts[1]
|
||||||
|
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to decode signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
|
||||||
|
t.logger.Errorf("Signature verification failed: %v", err)
|
||||||
return fmt.Errorf("signature verification failed: %w", err)
|
return fmt.Errorf("signature verification failed: %w", err)
|
||||||
}
|
}
|
||||||
|
t.logger.Debug("Signature verified successfully")
|
||||||
|
|
||||||
return jwt.Verify(t.issuerURL, t.clientID)
|
// Verify standard claims
|
||||||
|
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||||
|
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||||
|
}
|
||||||
|
t.logger.Debug("Standard claims verified successfully")
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||||
|
|||||||
+3
-2
@@ -247,8 +247,9 @@ func (suite *TraefikOidcTestSuite) TestBuildAuthURL() {
|
|||||||
|
|
||||||
func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
|
func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
|
||||||
jwk := &JWK{
|
jwk := &JWK{
|
||||||
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
Kty: "RSA", // Set the key type to RSA
|
||||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
||||||
}
|
}
|
||||||
pem, err := jwkToPEM(jwk)
|
pem, err := jwkToPEM(jwk)
|
||||||
suite.Require().NoError(err)
|
suite.Require().NoError(err)
|
||||||
|
|||||||
Reference in New Issue
Block a user