mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
9d52f1b018
- [x] Reorganize golangci-lint configuration with documented disable reasons - [x] Simplify errcheck and revive linter rules with targeted exclusions - [x] Pre-compile regex patterns in input_validation.go for performance - [x] Fix type assertions in memory_shard.go and resp.go with safety checks - [x] Replace string comparison with EqualFold for case-insensitive matching - [x] Fix loop variable captures in jwk.go and logout.go - [x] Change high goroutine log level from Info to Debug in autocleanup.go - [x] Replace deprecated "cancelled" spelling with "canceled" throughout - [x] Add nolint annotations for intentional unused parameters - [x] Improve comment formatting for deprecated functions - [x] Fix comment spelling: "marshalling" → "marshaling" - [x] Refactor provider warnings formatting in internal/providers/warnings.go - [x] Simplify metrics summary building in internal/recovery/metrics.go - [x] Pre-allocate slice in error_recovery.go GetDegradedServices - [x] Refactor context cancellation checks in redis.go
592 lines
18 KiB
Go
592 lines
18 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lukaszraczylo/traefikoidc/internal/pool"
|
|
)
|
|
|
|
// Replay attack protection cache using sharded design for reduced lock contention.
|
|
// This cache tracks JWT IDs (jti claims) to prevent token reuse attacks.
|
|
// Under high load (500+ req/sec), the sharded design reduces contention significantly.
|
|
var (
|
|
// replayCacheMu protects access to the replay cache instance (only used for initialization)
|
|
replayCacheMu sync.RWMutex
|
|
// replayCache stores JWT IDs with expiration to prevent replay attacks (legacy interface)
|
|
replayCache CacheInterface
|
|
// shardedReplayCache is the new high-performance sharded cache for replay detection
|
|
shardedReplayCache *ShardedCache
|
|
// replayCacheOnce ensures the replay cache is initialized only once
|
|
replayCacheOnce sync.Once
|
|
// replayCacheCleanupWG waits for cleanup goroutine to finish
|
|
replayCacheCleanupWG sync.WaitGroup
|
|
// replayCacheCancel cancels the cleanup context
|
|
replayCacheCancel context.CancelFunc
|
|
// replayCacheCleanupMu protects cleanup operations
|
|
replayCacheCleanupMu sync.Mutex
|
|
)
|
|
|
|
// initReplayCache initializes the JWT replay protection cache with bounded size.
|
|
// Uses a sharded cache design with 64 shards for reduced lock contention under high load.
|
|
// The cache is bounded to 10,000 entries to prevent unbounded memory growth.
|
|
// This function uses sync.Once to ensure thread-safe single initialization.
|
|
func initReplayCache() {
|
|
replayCacheOnce.Do(func() {
|
|
// Hold mutex during initialization to synchronize with cleanup goroutine
|
|
replayCacheMu.Lock()
|
|
defer replayCacheMu.Unlock()
|
|
|
|
// Create sharded cache with 64 shards for reduced contention
|
|
// Under 500 req/sec, this reduces lock contention by ~64x compared to single mutex
|
|
shardedReplayCache = NewShardedCache(64, 10000)
|
|
|
|
// Also initialize legacy cache for backward compatibility
|
|
replayCache = NewCache()
|
|
replayCache.SetMaxSize(10000)
|
|
})
|
|
}
|
|
|
|
// cleanupReplayCache performs graceful shutdown of the replay cache system.
|
|
// It cancels the cleanup context, waits for background goroutines to finish,
|
|
// and properly closes the cache to ensure proper cleanup during shutdown.
|
|
func cleanupReplayCache() {
|
|
replayCacheCleanupMu.Lock()
|
|
shouldWait := replayCacheCancel != nil
|
|
if replayCacheCancel != nil {
|
|
replayCacheCancel()
|
|
replayCacheCancel = nil
|
|
}
|
|
replayCacheCleanupMu.Unlock()
|
|
|
|
// Only wait if there was a cleanup routine running
|
|
if shouldWait {
|
|
replayCacheCleanupWG.Wait()
|
|
}
|
|
|
|
replayCacheMu.Lock()
|
|
defer replayCacheMu.Unlock()
|
|
|
|
// Clear sharded cache
|
|
if shardedReplayCache != nil {
|
|
shardedReplayCache.Clear()
|
|
shardedReplayCache = nil
|
|
}
|
|
|
|
// Clear legacy cache
|
|
if replayCache != nil {
|
|
replayCache.Close()
|
|
replayCache = nil
|
|
}
|
|
|
|
replayCacheOnce = sync.Once{}
|
|
}
|
|
|
|
// getReplayCacheStats returns statistics about the replay cache state.
|
|
// Returns:
|
|
// - size: Current number of entries in the cache
|
|
// - maxSize: Maximum allowed entries (10,000)
|
|
func getReplayCacheStats() (size int, maxSize int) {
|
|
// Use sharded cache if available (no mutex needed due to internal sharding)
|
|
if shardedReplayCache != nil {
|
|
return shardedReplayCache.Size(), 10000
|
|
}
|
|
|
|
// Fall back to legacy cache
|
|
replayCacheMu.RLock()
|
|
defer replayCacheMu.RUnlock()
|
|
|
|
if replayCache == nil {
|
|
return 0, 10000
|
|
}
|
|
|
|
return 0, 10000
|
|
}
|
|
|
|
// startReplayCacheCleanup starts a background goroutine for periodic cache maintenance.
|
|
// The goroutine runs every 5 minutes to clean expired entries and log cache statistics.
|
|
// Uses the global task registry with circuit breaker pattern to prevent duplicate tasks.
|
|
// Parameters:
|
|
// - ctx: Parent context for cancellation
|
|
// - logger: Logger for debug output (can be nil)
|
|
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
|
|
registry := GetGlobalTaskRegistry()
|
|
|
|
// Define the cleanup task function
|
|
cleanupFunc := func() {
|
|
// Use mutex to safely access cache pointers - this prevents race with initReplayCache
|
|
replayCacheMu.RLock()
|
|
shardedCache := shardedReplayCache
|
|
legacyCache := replayCache
|
|
replayCacheMu.RUnlock()
|
|
|
|
// Only proceed if caches have been initialized
|
|
if shardedCache == nil && legacyCache == nil {
|
|
return
|
|
}
|
|
|
|
size, maxSize := getReplayCacheStats()
|
|
if logger != nil {
|
|
logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize)
|
|
}
|
|
|
|
// Clean up sharded cache
|
|
if shardedCache != nil {
|
|
shardedCache.Cleanup()
|
|
}
|
|
|
|
// Also clean up legacy cache for backward compatibility
|
|
if legacyCache != nil {
|
|
legacyCache.Cleanup()
|
|
}
|
|
}
|
|
|
|
// Create or get singleton cleanup task
|
|
task, err := registry.CreateSingletonTask(
|
|
"replay-cache-cleanup",
|
|
5*time.Minute,
|
|
cleanupFunc,
|
|
logger,
|
|
&replayCacheCleanupWG,
|
|
)
|
|
|
|
if err != nil {
|
|
if logger != nil {
|
|
logger.Debugf("Replay cache cleanup task already exists or circuit breaker limit reached: %v (this is expected with multiple instances)", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Start the task
|
|
task.Start()
|
|
|
|
if logger != nil {
|
|
logger.Debug("Started replay cache cleanup task with circuit breaker protection")
|
|
}
|
|
}
|
|
|
|
// Clock skew tolerances applied when validating time-based JWT claims.
var (
	// ClockSkewToleranceFuture is the grace period added to a token's
	// expiration time: a token is still accepted for up to 2 minutes past
	// its exp claim.
	ClockSkewToleranceFuture = 2 * time.Minute

	// ClockSkewTolerancePast is the grace period applied to iat/nbf claims:
	// a token is accepted if its claim time lies up to 10 seconds in the
	// future.
	ClockSkewTolerancePast = 10 * time.Second

	// ClockSkewTolerance is kept for backward compatibility. It is
	// initialized from ClockSkewToleranceFuture but is a separate variable,
	// so later changes to ClockSkewToleranceFuture are not reflected here.
	ClockSkewTolerance = ClockSkewToleranceFuture
)
|
|
|
|
// JWT is the decoded form of a JSON Web Token as used by the OIDC
// middleware for validation and claim processing.
type JWT struct {
	// Header holds the decoded JWT header (alg, typ, kid, etc.);
	// Claims holds the decoded JWT payload (iss, sub, aud, exp, etc.).
	Header, Claims map[string]interface{}

	// Token is the original, still-encoded JWT string.
	Token string

	// Signature is the base64url-decoded signature segment.
	Signature []byte
}
|
|
|
|
// parseJWT parses a JWT token string into its constituent parts.
|
|
// It decodes the base64url-encoded header, claims, and signature components
|
|
// and unmarshals the JSON data into structured maps. Uses memory pools
|
|
// for efficient memory allocation during parsing.
|
|
// Parameters:
|
|
// - tokenString: The JWT token string to parse
|
|
//
|
|
// Returns:
|
|
// - *JWT: Parsed JWT structure with header, claims, and signature
|
|
// - An error if the token format is invalid or decoding/unmarshaling fails
|
|
func parseJWT(tokenString string) (*JWT, error) {
|
|
parts := strings.Split(tokenString, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
|
}
|
|
|
|
pm := pool.Get()
|
|
jwtBuf := pm.GetJWTBuffer()
|
|
defer pm.PutJWTBuffer(jwtBuf)
|
|
|
|
jwt := &JWT{
|
|
Token: tokenString,
|
|
}
|
|
|
|
headerLen := base64.RawURLEncoding.DecodedLen(len(parts[0]))
|
|
if headerLen > cap(jwtBuf.Header) {
|
|
jwtBuf.Header = make([]byte, headerLen)
|
|
} else {
|
|
jwtBuf.Header = jwtBuf.Header[:headerLen]
|
|
}
|
|
|
|
n, err := base64.RawURLEncoding.Decode(jwtBuf.Header, []byte(parts[0]))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
|
}
|
|
headerBytes := jwtBuf.Header[:n]
|
|
|
|
decoder := pm.GetJSONDecoder(bytes.NewReader(headerBytes))
|
|
defer pm.PutJSONDecoder(decoder)
|
|
if err := decoder.Decode(&jwt.Header); err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
if jwt.Header == nil {
|
|
return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
|
|
}
|
|
|
|
claimsLen := base64.RawURLEncoding.DecodedLen(len(parts[1]))
|
|
if claimsLen > cap(jwtBuf.Payload) {
|
|
jwtBuf.Payload = make([]byte, claimsLen)
|
|
} else {
|
|
jwtBuf.Payload = jwtBuf.Payload[:claimsLen]
|
|
}
|
|
|
|
n, err = base64.RawURLEncoding.Decode(jwtBuf.Payload, []byte(parts[1]))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
|
}
|
|
claimsBytes := jwtBuf.Payload[:n]
|
|
|
|
decoder2 := pm.GetJSONDecoder(bytes.NewReader(claimsBytes))
|
|
defer pm.PutJSONDecoder(decoder2)
|
|
if err := decoder2.Decode(&jwt.Claims); err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
|
}
|
|
|
|
if jwt.Claims == nil {
|
|
return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
|
|
}
|
|
|
|
sigLen := base64.RawURLEncoding.DecodedLen(len(parts[2]))
|
|
if sigLen > cap(jwtBuf.Signature) {
|
|
jwtBuf.Signature = make([]byte, sigLen)
|
|
} else {
|
|
jwtBuf.Signature = jwtBuf.Signature[:sigLen]
|
|
}
|
|
|
|
n, err = base64.RawURLEncoding.Decode(jwtBuf.Signature, []byte(parts[2]))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
|
}
|
|
|
|
// Reuse the signature buffer if it's large enough, otherwise allocate
|
|
if cap(jwtBuf.Signature) >= n {
|
|
jwt.Signature = jwtBuf.Signature[:n:n] // Use slice trick to prevent aliasing
|
|
} else {
|
|
jwt.Signature = make([]byte, n)
|
|
copy(jwt.Signature, jwtBuf.Signature[:n])
|
|
}
|
|
|
|
return jwt, nil
|
|
}
|
|
|
|
// Verify performs comprehensive JWT token validation according to OIDC specifications.
|
|
// It validates the token signature algorithm, issuer, audience, expiration, issued-at time,
|
|
// not-before time (if present), and prevents replay attacks using JTI claims.
|
|
// Parameters:
|
|
// - issuerURL: Expected issuer URL to validate against
|
|
// - expectedAudience: Expected audience to validate against (can be clientID or custom audience)
|
|
// - skipReplayCheck: Optional parameter to skip replay attack protection
|
|
//
|
|
// Returns:
|
|
// - An error describing the first validation failure encountered
|
|
func (j *JWT) Verify(issuerURL, expectedAudience string, skipReplayCheck ...bool) error {
|
|
alg, ok := j.Header["alg"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing 'alg' header")
|
|
}
|
|
supportedAlgs := map[string]bool{
|
|
"RS256": true, "RS384": true, "RS512": true,
|
|
"PS256": true, "PS384": true, "PS512": true,
|
|
"ES256": true, "ES384": true, "ES512": true,
|
|
}
|
|
if !supportedAlgs[alg] {
|
|
return fmt.Errorf("unsupported algorithm: %s", alg)
|
|
}
|
|
|
|
claims := j.Claims
|
|
|
|
iss, ok := claims["iss"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing 'iss' claim")
|
|
}
|
|
if err := verifyIssuer(iss, issuerURL); err != nil {
|
|
return err
|
|
}
|
|
|
|
aud, ok := claims["aud"]
|
|
if !ok {
|
|
return fmt.Errorf("missing 'aud' claim")
|
|
}
|
|
if err := verifyAudience(aud, expectedAudience); err != nil {
|
|
return err
|
|
}
|
|
|
|
exp, ok := claims["exp"].(float64)
|
|
if !ok {
|
|
return fmt.Errorf("missing or invalid 'exp' claim")
|
|
}
|
|
if err := verifyExpiration(exp); err != nil {
|
|
return err
|
|
}
|
|
|
|
iat, ok := claims["iat"].(float64)
|
|
if !ok {
|
|
return fmt.Errorf("missing or invalid 'iat' claim")
|
|
}
|
|
if err := verifyIssuedAt(iat); err != nil {
|
|
return err
|
|
}
|
|
|
|
if nbf, ok := claims["nbf"].(float64); ok {
|
|
if err := verifyNotBefore(nbf); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
|
|
|
|
jtiValue, jtiOk := claims["jti"].(string)
|
|
|
|
if jtiOk && !shouldSkipReplay && jtiValue != "" {
|
|
initReplayCache()
|
|
|
|
// Use sharded cache for replay detection - no global mutex needed
|
|
// This reduces lock contention by ~64x under high load
|
|
if shardedReplayCache != nil {
|
|
if shardedReplayCache.Exists(jtiValue) {
|
|
return fmt.Errorf("token replay detected (jti: %s)", jtiValue)
|
|
}
|
|
|
|
expFloat, ok := claims["exp"].(float64)
|
|
var expTime time.Time
|
|
if ok {
|
|
expTime = time.Unix(int64(expFloat), 0)
|
|
} else {
|
|
expTime = time.Now().Add(10 * time.Minute)
|
|
}
|
|
|
|
duration := time.Until(expTime)
|
|
if duration > 0 {
|
|
shardedReplayCache.Set(jtiValue, true, duration)
|
|
}
|
|
} else {
|
|
// Fall back to legacy cache with mutex (should rarely happen)
|
|
replayCacheMu.RLock()
|
|
_, exists := replayCache.Get(jtiValue)
|
|
replayCacheMu.RUnlock()
|
|
|
|
if exists {
|
|
return fmt.Errorf("token replay detected (jti: %s)", jtiValue)
|
|
}
|
|
|
|
expFloat, ok := claims["exp"].(float64)
|
|
var expTime time.Time
|
|
if ok {
|
|
expTime = time.Unix(int64(expFloat), 0)
|
|
} else {
|
|
expTime = time.Now().Add(10 * time.Minute)
|
|
}
|
|
|
|
duration := time.Until(expTime)
|
|
if duration > 0 {
|
|
replayCacheMu.Lock()
|
|
if replayCache != nil {
|
|
replayCache.Set(jtiValue, true, duration)
|
|
}
|
|
replayCacheMu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
sub, ok := claims["sub"].(string)
|
|
if !ok || sub == "" {
|
|
return fmt.Errorf("missing or empty 'sub' claim")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// verifyAudience checks that the JWT "aud" claim contains the expected audience.
// The claim may be a single string or a JSON array; in an array, only string
// elements are considered and one match is enough.
// Parameters:
//   - tokenAudience: The audience claim from the JWT (string or []interface{})
//   - expectedAudience: The expected audience value (typically the OAuth client ID)
//
// Returns:
//   - nil when the expected audience is present
//   - "invalid audience" when it is absent
//   - "invalid 'aud' claim type" when the claim is neither a string nor an array
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
	switch aud := tokenAudience.(type) {
	case string:
		if aud == expectedAudience {
			return nil
		}
	case []interface{}:
		for _, entry := range aud {
			if candidate, isString := entry.(string); isString && candidate == expectedAudience {
				return nil
			}
		}
	default:
		return fmt.Errorf("invalid 'aud' claim type")
	}

	// Reached only for a well-typed claim that did not match.
	return fmt.Errorf("invalid audience")
}
|
|
|
|
// verifyIssuer checks the JWT "iss" claim against the configured issuer.
// The comparison is an exact, case-sensitive string match.
// Parameters:
//   - tokenIssuer: The issuer claim from the JWT
//   - expectedIssuer: The expected issuer URL from OIDC configuration
//
// Returns:
//   - nil on a match, otherwise an error naming both values
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
	if tokenIssuer == expectedIssuer {
		return nil
	}
	return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
}
|
|
|
|
// verifyTimeConstraint validates time-based JWT claims with clock skew tolerance.
|
|
// It handles both future constraints (exp) and past constraints (iat, nbf).
|
|
// Parameters:
|
|
// - unixTime: The Unix timestamp from the JWT claim
|
|
// - claimName: Name of the claim being validated (for error messages)
|
|
// - future: If true, validates against future tolerance; if false, against past tolerance
|
|
//
|
|
// Returns:
|
|
// - An error describing the failure (e.g., "token has expired", "token used before issued")
|
|
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
|
claimTime := time.Unix(int64(unixTime), 0)
|
|
now := time.Now()
|
|
|
|
var err error
|
|
if future {
|
|
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
|
if now.After(allowedExpiry) {
|
|
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
|
}
|
|
} else {
|
|
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
|
if now.Before(allowedStart) {
|
|
reason := "not yet valid"
|
|
if claimName == "iat" {
|
|
reason = "used before issued"
|
|
}
|
|
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// verifyExpiration validates the JWT expiration time (exp claim) with clock skew tolerance.
|
|
// It calls verifyTimeConstraint with future=true.
|
|
func verifyExpiration(expiration float64) error {
|
|
return verifyTimeConstraint(expiration, "exp", true)
|
|
}
|
|
|
|
// verifyIssuedAt validates the JWT issued-at time (iat claim) with clock skew tolerance.
|
|
// It calls verifyTimeConstraint with future=false.
|
|
func verifyIssuedAt(issuedAt float64) error {
|
|
return verifyTimeConstraint(issuedAt, "iat", false)
|
|
}
|
|
|
|
// verifyNotBefore validates the JWT not-before time (nbf claim) with clock skew tolerance.
|
|
// It calls verifyTimeConstraint with future=false.
|
|
func verifyNotBefore(notBefore float64) error {
|
|
return verifyTimeConstraint(notBefore, "nbf", false)
|
|
}
|
|
|
|
// verifySignature verifies the JWT signature using the provided public key.
// Supports RSA (RS256/384/512, PS256/384/512) and ECDSA (ES256/384/512) algorithms.
//
// The signed content is the first two token segments joined by ".", hashed
// with the SHA-2 variant selected by alg. The key family must match the
// algorithm family (RSA key for RS*/PS*, ECDSA key for ES*).
//
// For ECDSA, the signature must be the fixed-width concatenation R || S,
// each half exactly as long as the curve's coordinate size (64 bytes total
// for P-256, 96 for P-384, 132 for P-521). Any other length is rejected
// before verification instead of being split in half and interpreted.
//
// Parameters:
//   - tokenString: The complete JWT token string
//   - publicKeyPEM: The public key in PEM format (PKIX / SubjectPublicKeyInfo)
//   - alg: The signing algorithm specified in the JWT header
//
// Returns:
//   - An error if the token or key cannot be parsed, the algorithm is
//     unsupported or does not match the key type, or the signature is invalid
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return fmt.Errorf("invalid token format")
	}
	signedContent := parts[0] + "." + parts[1]

	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return fmt.Errorf("failed to decode signature: %w", err)
	}

	block, _ := pem.Decode(publicKeyPEM)
	if block == nil {
		return fmt.Errorf("failed to parse PEM block containing the public key")
	}
	pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return fmt.Errorf("failed to parse public key: %w", err)
	}

	var hashFunc crypto.Hash
	switch alg {
	case "RS256", "PS256", "ES256":
		hashFunc = crypto.SHA256
	case "RS384", "PS384", "ES384":
		hashFunc = crypto.SHA384
	case "RS512", "PS512", "ES512":
		hashFunc = crypto.SHA512
	default:
		return fmt.Errorf("unsupported algorithm: %s", alg)
	}
	// crypto.Hash.New panics when the implementation is not linked into the
	// binary; report that as an error instead.
	if !hashFunc.Available() {
		return fmt.Errorf("hash function for algorithm %s is not available", alg)
	}

	h := hashFunc.New()
	h.Write([]byte(signedContent)) // hash.Hash documents that Write never returns an error
	hashed := h.Sum(nil)

	switch key := pubKey.(type) {
	case *rsa.PublicKey:
		switch {
		case strings.HasPrefix(alg, "RS"):
			return rsa.VerifyPKCS1v15(key, hashFunc, hashed, signature)
		case strings.HasPrefix(alg, "PS"):
			// nil options: salt length is auto-detected and the hash is the one passed in.
			return rsa.VerifyPSS(key, hashFunc, hashed, signature, nil)
		default:
			return fmt.Errorf("unexpected key type for algorithm %s", alg)
		}

	case *ecdsa.PublicKey:
		if !strings.HasPrefix(alg, "ES") {
			return fmt.Errorf("unexpected key type for algorithm %s", alg)
		}

		// JWS encodes an ECDSA signature as R || S, each padded to the
		// curve's coordinate size in bytes.
		coordSize := (key.Curve.Params().BitSize + 7) / 8
		if len(signature) != 2*coordSize {
			return fmt.Errorf("invalid ECDSA signature length: got %d bytes, want %d", len(signature), 2*coordSize)
		}

		var r, s big.Int
		r.SetBytes(signature[:coordSize])
		s.SetBytes(signature[coordSize:])
		if !ecdsa.Verify(key, hashed, &r, &s) {
			return fmt.Errorf("invalid ECDSA signature")
		}
		return nil

	default:
		return fmt.Errorf("unsupported public key type: %T", pubKey)
	}
}
|