Add support for more algorithms.

This commit is contained in:
2024-10-07 16:01:07 +01:00
parent 345c0c4a11
commit dc4c4824cd
2 changed files with 56 additions and 41 deletions
+10 -6
View File
@@ -117,14 +117,18 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
}
func jwkToPEM(jwk *JWK) ([]byte, error) {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
converter, ok := jwkConverters[jwk.Kty]
if !ok {
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return converter(jwk)
}
type jwkToPEMConverter func(*JWK) ([]byte, error)
var jwkConverters = map[string]jwkToPEMConverter{
"RSA": rsaJWKToPEM,
"EC": ecJWKToPEM,
}
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
+46 -35
View File
@@ -114,52 +114,63 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte
return fmt.Errorf("failed to parse public key: %w", err)
}
var hashFunc crypto.Hash
var hash crypto.Hash
var verifyFunc func(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
case "RS256", "RS384", "RS512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for RS384 and RS512 respectively.
verifyFunc = rsaVerifyPKCS1v15
case "PS256", "PS384", "PS512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for PS384 and PS512 respectively.
verifyFunc = rsaVerifyPSS
case "ES256", "ES384", "ES512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for ES384 and ES512 respectively.
verifyFunc = ecdsaVerify
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
h := hashFunc.New()
h := hash.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
return verifyFunc(pubKey, hashed, signature, hash)
}
if ecdsa.Verify(pub, hashed, r, s) {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pub)
func rsaVerifyPKCS1v15(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for RSA: %T", publicKey)
}
return rsa.VerifyPKCS1v15(pubKey, hash, hashed, signature)
}
func rsaVerifyPSS(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for RSA: %T", publicKey)
}
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
return rsa.VerifyPSS(pubKey, crypto.SHA256, hashed, signature, opts)
}
func ecdsaVerify(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for ECDSA: %T", publicKey)
}
keyBytes := (pubKey.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length for ECDSA: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pubKey, hashed, r, s) {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
func verifyIssuedAt(issuedAt float64) error {