diff --git a/jwt.go b/jwt.go index a04dfef..feae914 100644 --- a/jwt.go +++ b/jwt.go @@ -4,7 +4,7 @@ import ( "crypto" "crypto/ecdsa" "crypto/rsa" - "crypto/sha256" + "math/big" "strings" "crypto/x509" @@ -157,37 +157,82 @@ func verifyIssuedAt(issuedAt float64) error { return nil } -// verifySignature verifies the token signature +// verifySignature verifies the token signature using the provided public key and algorithm func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { + // Split the token into its three parts parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid token format") + } signedContent := parts[0] + "." + parts[1] + + // Decode the signature from the token signature, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } + // Decode the PEM-encoded public key block, _ := pem.Decode(publicKeyPEM) if block == nil { return fmt.Errorf("failed to parse PEM block containing the public key") } + // Parse the public key pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return fmt.Errorf("failed to parse public key: %w", err) } - h := sha256.New() + // Determine the hash function to use based on the algorithm + var hashFunc crypto.Hash + + switch alg { + case "RS256", "PS256", "ES256": + hashFunc = crypto.SHA256 + case "RS384", "PS384", "ES384": + hashFunc = crypto.SHA384 + case "RS512", "PS512", "ES512": + hashFunc = crypto.SHA512 + default: + return fmt.Errorf("unsupported algorithm: %s", alg) + } + + // Hash the signed content + h := hashFunc.New() h.Write([]byte(signedContent)) hashed := h.Sum(nil) + // Verify the signature based on the key type and algorithm switch pubKey := pubKey.(type) { case *rsa.PublicKey: - return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hashed, signature) - case *ecdsa.PublicKey: - if !ecdsa.VerifyASN1(pubKey, hashed, signature) { - return fmt.Errorf("invalid ECDSA signature") + if strings.HasPrefix(alg, "RS") { + // RSA PKCS#1 v1.5 signature + return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature) + } else if strings.HasPrefix(alg, "PS") { + // RSA PSS signature + return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil) + } else { + return fmt.Errorf("unexpected key type for algorithm %s", alg) + } + case *ecdsa.PublicKey: + if strings.HasPrefix(alg, "ES") { + // ECDSA signature + var r, s big.Int + sigLen := len(signature) + if sigLen%2 != 0 { + return fmt.Errorf("invalid ECDSA signature length") + } + r.SetBytes(signature[:sigLen/2]) + s.SetBytes(signature[sigLen/2:]) + if ecdsa.Verify(pubKey, hashed, &r, &s) { + return nil + } else { + return fmt.Errorf("invalid ECDSA signature") + } + } else { + return fmt.Errorf("unexpected key type for algorithm %s", alg) } - return nil default: return fmt.Errorf("unsupported public key type: %T", pubKey) }