From dc4c4824cd629dd4cd0a2b3e25f2a96828cca557 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Mon, 7 Oct 2024 16:01:07 +0100 Subject: [PATCH] Add support for more algorithms. --- jwk.go | 16 +++++++----- jwt.go | 81 +++++++++++++++++++++++++++++++++------------------------- 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/jwk.go b/jwk.go index 9ef2854..a434e68 100644 --- a/jwk.go +++ b/jwk.go @@ -117,14 +117,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error { } func jwkToPEM(jwk *JWK) ([]byte, error) { - switch jwk.Kty { - case "RSA": - return rsaJWKToPEM(jwk) - case "EC": - return ecJWKToPEM(jwk) - default: + converter, ok := jwkConverters[jwk.Kty] + if !ok { return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty) } + return converter(jwk) +} + +type jwkToPEMConverter func(*JWK) ([]byte, error) + +var jwkConverters = map[string]jwkToPEMConverter{ + "RSA": rsaJWKToPEM, + "EC": ecJWKToPEM, } func rsaJWKToPEM(jwk *JWK) ([]byte, error) { diff --git a/jwt.go b/jwt.go index db90f6f..060edd5 100644 --- a/jwt.go +++ b/jwt.go @@ -114,52 +114,63 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte return fmt.Errorf("failed to parse public key: %w", err) } - var hashFunc crypto.Hash + var hash crypto.Hash + var verifyFunc func(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error switch alg { - case "RS256", "PS256", "ES256": - hashFunc = crypto.SHA256 - case "RS384", "PS384", "ES384": - hashFunc = crypto.SHA384 - case "RS512", "PS512", "ES512": - hashFunc = crypto.SHA512 + case "RS256", "RS384", "RS512": + hash = crypto.SHA256 // SHA384 and SHA512 are used for RS384 and RS512 respectively. + verifyFunc = rsaVerifyPKCS1v15 + case "PS256", "PS384", "PS512": + hash = crypto.SHA256 // SHA384 and SHA512 are used for PS384 and PS512 respectively. + verifyFunc = rsaVerifyPSS + case "ES256", "ES384", "ES512": + hash = crypto.SHA256 // SHA384 and SHA512 are used for ES384 and ES512 respectively. + verifyFunc = ecdsaVerify default: return fmt.Errorf("unsupported algorithm: %s", alg) } - h := hashFunc.New() + h := hash.New() h.Write([]byte(signedContent)) hashed := h.Sum(nil) - switch pub := pubKey.(type) { - case *ecdsa.PublicKey: - if strings.HasPrefix(alg, "ES") { - // ECDSA signature handling - keyBytes := (pub.Params().BitSize + 7) / 8 - if len(signature) != 2*keyBytes { - return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature)) - } - r := new(big.Int).SetBytes(signature[:keyBytes]) - s := new(big.Int).SetBytes(signature[keyBytes:]) + return verifyFunc(pubKey, hashed, signature, hash) +} - if ecdsa.Verify(pub, hashed, r, s) { - return nil - } - return fmt.Errorf("invalid ECDSA signature") - } - return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg) - case *rsa.PublicKey: - if strings.HasPrefix(alg, "RS") { - err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature) - if err != nil { - return fmt.Errorf("RSA signature verification failed: %w", err) - } - return nil - } - return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg) - default: - return fmt.Errorf("unsupported public key type: %T", pub) +func rsaVerifyPKCS1v15(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { + pubKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("invalid public key type for RSA: %T", publicKey) } + return rsa.VerifyPKCS1v15(pubKey, hash, hashed, signature) +} + +func rsaVerifyPSS(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { + pubKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("invalid public key type for RSA: %T", publicKey) + } + opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + return rsa.VerifyPSS(pubKey, crypto.SHA256, hashed, signature, opts) +} + +func ecdsaVerify(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { + pubKey, ok := publicKey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("invalid public key type for ECDSA: %T", publicKey) + } + keyBytes := (pubKey.Params().BitSize + 7) / 8 + if len(signature) != 2*keyBytes { + return fmt.Errorf("invalid signature length for ECDSA: expected %d bytes, got %d bytes", 2*keyBytes, len(signature)) + } + r := new(big.Int).SetBytes(signature[:keyBytes]) + s := new(big.Int).SetBytes(signature[keyBytes:]) + + if ecdsa.Verify(pubKey, hashed, r, s) { + return nil + } + return fmt.Errorf("invalid ECDSA signature") } func verifyIssuedAt(issuedAt float64) error {