// Mirror of https://github.com/lukaszraczylo/traefikoidc.git,
// synced 2026-06-05 22:44:17 +00:00 (commit 1b49e133da).
// Original file: 517 lines, 16 KiB, Go.

package traefikoidc

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"
)

// Replay attack protection cache and synchronization primitives.
// The cache tracks JWT IDs (jti claims) so that reused tokens are rejected.
var (
	// replayCacheMu protects access to the replay cache instance.
	replayCacheMu sync.RWMutex
	// replayCache stores JWT IDs with an expiration to prevent replay attacks.
	replayCache CacheInterface
	// replayCacheOnce ensures the replay cache is initialized only once.
	replayCacheOnce sync.Once
	// replayCacheCleanupWG waits for the cleanup goroutine to finish.
	replayCacheCleanupWG sync.WaitGroup
	// replayCacheCancel cancels the cleanup context.
	replayCacheCancel context.CancelFunc
	// replayCacheCleanupMu protects cleanup operations.
	replayCacheCleanupMu sync.Mutex
)

// initReplayCache initializes the JWT replay protection cache with a bounded size.
// The cache is capped at 10,000 entries to prevent unbounded memory growth.
// sync.Once guarantees a single initialization. The assignment is made under
// replayCacheMu so that it is synchronized with code that reads replayCache
// under the read lock; do not call this function while holding replayCacheMu.
func initReplayCache() {
	replayCacheOnce.Do(func() {
		replayCacheMu.Lock()
		defer replayCacheMu.Unlock()
		replayCache = NewCache()
		replayCache.SetMaxSize(10000)
	})
}
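
/*
Illustrative sketch, not code from this package: the once-only lazy
initialization that initReplayCache relies on, with the write made under the
same RWMutex that readers use. boundedCache, initCache, created and
runConcurrentInits are hypothetical names; the real cache comes from NewCache
and CacheInterface, which are defined elsewhere.

```go
package main

import (
	"fmt"
	"sync"
)

// boundedCache is a hypothetical stand-in for the package's cache type.
type boundedCache struct {
	maxSize int
}

var (
	cacheMu   sync.RWMutex
	cache     *boundedCache
	cacheOnce sync.Once
	created   int // counts how many times the cache was constructed
)

// initCache runs the constructor exactly once; the write happens under
// cacheMu so readers that take the read lock are synchronized with it.
func initCache() {
	cacheOnce.Do(func() {
		cacheMu.Lock()
		defer cacheMu.Unlock()
		created++
		cache = &boundedCache{maxSize: 10000}
	})
}

// runConcurrentInits calls initCache from n goroutines at once.
func runConcurrentInits(n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			initCache()
		}()
	}
	wg.Wait()
}

func main() {
	runConcurrentInits(100)
	cacheMu.RLock()
	defer cacheMu.RUnlock()
	fmt.Println(created, cache.maxSize) // 1 10000
}
```

However many goroutines race to initialize, the constructor body runs once.
*/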

// cleanupReplayCache performs a graceful shutdown of the replay cache.
// It cancels the cleanup context, waits for the background goroutine to
// finish, closes the cache, and resets replayCacheOnce so that the cache
// can be initialized again later.
func cleanupReplayCache() {
	replayCacheCleanupMu.Lock()
	shouldWait := replayCacheCancel != nil
	if replayCacheCancel != nil {
		replayCacheCancel()
		replayCacheCancel = nil
	}
	replayCacheCleanupMu.Unlock()

	// Only wait if a cleanup routine was running.
	if shouldWait {
		replayCacheCleanupWG.Wait()
	}

	replayCacheMu.Lock()
	defer replayCacheMu.Unlock()

	if replayCache != nil {
		replayCache.Close()
		replayCache = nil
		replayCacheOnce = sync.Once{}
	}
}
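
/*
Illustrative sketch, assuming a ticker-driven goroutine like the one the task
registry manages: the cancel-then-wait shutdown order used by
cleanupReplayCache. startWorker and runAndShutdown are hypothetical helpers.
Cancelling the context only signals the goroutine; WaitGroup.Wait is what
guarantees it has returned.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// startWorker launches a goroutine that calls fn every interval until ctx is
// cancelled. The WaitGroup lets the caller block until the goroutine exits.
func startWorker(ctx context.Context, wg *sync.WaitGroup, interval time.Duration, fn func()) {
	wg.Add(1)
	go func() {
		defer wg.Done()
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				fn()
			}
		}
	}()
}

// runAndShutdown starts a worker, lets it tick for a while, then cancels and
// waits. It returns the tick count and whether any tick ran after Wait.
func runAndShutdown() (ticks int64, tickedAfterStop bool) {
	var wg sync.WaitGroup
	var count int64
	ctx, cancel := context.WithCancel(context.Background())

	startWorker(ctx, &wg, 5*time.Millisecond, func() { atomic.AddInt64(&count, 1) })
	time.Sleep(30 * time.Millisecond)

	cancel()  // signal the goroutine to stop
	wg.Wait() // block until it has actually returned

	before := atomic.LoadInt64(&count)
	time.Sleep(20 * time.Millisecond)
	after := atomic.LoadInt64(&count)
	return before, after != before
}

func main() {
	ticks, late := runAndShutdown()
	fmt.Println(ticks, late)
}
```

cleanupReplayCache follows the same order: cancel, Wait, then Close.
*/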

// getReplayCacheStats returns statistics about the replay cache state.
// Returns:
//   - size: current number of entries in the cache (currently always 0
//     due to CacheInterface limitations)
//   - maxSize: maximum allowed entries (10,000)
func getReplayCacheStats() (size int, maxSize int) {
	replayCacheMu.RLock()
	defer replayCacheMu.RUnlock()

	if replayCache == nil {
		return 0, 10000
	}

	return 0, 10000
}

// startReplayCacheCleanup starts a background task for periodic cache maintenance.
// The task runs every 5 minutes, logs cache statistics, and removes expired entries.
// It is registered in the global task registry as a singleton with circuit
// breaker protection, so duplicate tasks are not created.
// Parameters:
//   - ctx: Parent context for cancellation (not referenced in this function body)
//   - logger: Logger for debug output (may be nil)
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
	registry := GetGlobalTaskRegistry()

	// Define the cleanup task function.
	cleanupFunc := func() {
		size, maxSize := getReplayCacheStats()
		if logger != nil {
			logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize)
		}

		replayCacheMu.RLock()
		if replayCache != nil {
			replayCache.Cleanup()
		}
		replayCacheMu.RUnlock()
	}

	// Create the singleton cleanup task. An error means the task already
	// exists or the circuit breaker limit was reached.
	task, err := registry.CreateSingletonTask(
		"replay-cache-cleanup",
		5*time.Minute,
		cleanupFunc,
		logger,
		&replayCacheCleanupWG,
	)

	if err != nil {
		if logger != nil {
			logger.Debugf("Replay cache cleanup task already exists or circuit breaker limit reached: %v (this is expected with multiple instances)", err)
		}
		return
	}

	// Start the task.
	task.Start()

	if logger != nil {
		logger.Debug("Started replay cache cleanup task with circuit breaker protection")
	}
}
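
/*
Illustrative sketch of the singleton behavior startReplayCacheCleanup expects
from the task registry. taskRegistry and createSingleton are hypothetical and
much simpler than the package's GetGlobalTaskRegistry and CreateSingletonTask
(no scheduling, no circuit breaker); they only show how several middleware
instances can call the function while a single cleanup task is registered.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// taskRegistry allows at most one task per name (hypothetical sketch).
type taskRegistry struct {
	mu    sync.Mutex
	tasks map[string]func()
}

var errTaskExists = errors.New("task already exists")

func newTaskRegistry() *taskRegistry {
	return &taskRegistry{tasks: make(map[string]func())}
}

// createSingleton registers fn under name, or returns an error wrapping
// errTaskExists if a task with that name is already registered.
func (r *taskRegistry) createSingleton(name string, fn func()) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, ok := r.tasks[name]; ok {
		return fmt.Errorf("%s: %w", name, errTaskExists)
	}
	r.tasks[name] = fn
	return nil
}

// registerFromInstances simulates n instances registering the same task
// concurrently and returns how many registrations succeeded.
func registerFromInstances(r *taskRegistry, n int) int {
	var wg sync.WaitGroup
	var mu sync.Mutex
	ok := 0
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := r.createSingleton("replay-cache-cleanup", func() {}); err == nil {
				mu.Lock()
				ok++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	return ok
}

func main() {
	r := newTaskRegistry()
	fmt.Println(registerFromInstances(r, 10)) // 1
	err := r.createSingleton("replay-cache-cleanup", func() {})
	fmt.Println(errors.Is(err, errTaskExists)) // true
}
```

Every caller except the first receives an error, which is the case
startReplayCacheCleanup logs at debug level before returning.
*/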

// ClockSkewToleranceFuture is the clock skew allowance applied to the exp claim.
// A token is still accepted for up to 2 minutes after its expiration time.
var ClockSkewToleranceFuture = 2 * time.Minute

// ClockSkewTolerancePast is the clock skew allowance applied to the iat and nbf claims.
// A token is accepted if those timestamps are at most 10 seconds in the future.
var ClockSkewTolerancePast = 10 * time.Second

// ClockSkewTolerance is kept for backward compatibility. It is initialized from
// ClockSkewToleranceFuture at package initialization; it is a copy, not an alias,
// so later changes to either variable do not affect the other.
var ClockSkewTolerance = ClockSkewToleranceFuture

// JWT represents a parsed JSON Web Token and its constituent parts,
// used for validation and processing within the OIDC middleware.
type JWT struct {
	// Header contains the JWT header parameters (alg, typ, kid, etc.).
	Header map[string]interface{}
	// Claims contains the JWT payload claims (iss, sub, aud, exp, etc.).
	Claims map[string]interface{}
	// Token is the original JWT token string.
	Token string
	// Signature contains the decoded JWT signature bytes.
	Signature []byte
}

// parseJWT parses a JWT token string into its constituent parts.
// It decodes the base64url-encoded header, claims, and signature segments and
// unmarshals the header and claims JSON into maps. Decoding reuses pooled
// buffers to reduce allocations; the signature is copied out before the
// buffers are returned to the pool. The signature is not verified here.
// Parameters:
//   - tokenString: The JWT token string to parse
//
// Returns:
//   - *JWT: Parsed JWT structure with header, claims, and signature
//   - An error if the token format is invalid or decoding/unmarshaling fails
func parseJWT(tokenString string) (*JWT, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
	}

	pools := GetGlobalMemoryPools()
	jwtBuf := pools.GetJWTParsingBuffer()
	defer pools.PutJWTParsingBuffer(jwtBuf)

	jwt := &JWT{
		Token: tokenString,
	}

	headerLen := base64.RawURLEncoding.DecodedLen(len(parts[0]))
	if headerLen > cap(jwtBuf.HeaderBuf) {
		jwtBuf.HeaderBuf = make([]byte, headerLen)
	} else {
		jwtBuf.HeaderBuf = jwtBuf.HeaderBuf[:headerLen]
	}

	n, err := base64.RawURLEncoding.Decode(jwtBuf.HeaderBuf, []byte(parts[0]))
	if err != nil {
		return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
	}
	headerBytes := jwtBuf.HeaderBuf[:n]

	if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
		return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
	}

	if jwt.Header == nil {
		return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
	}

	claimsLen := base64.RawURLEncoding.DecodedLen(len(parts[1]))
	if claimsLen > cap(jwtBuf.PayloadBuf) {
		jwtBuf.PayloadBuf = make([]byte, claimsLen)
	} else {
		jwtBuf.PayloadBuf = jwtBuf.PayloadBuf[:claimsLen]
	}

	n, err = base64.RawURLEncoding.Decode(jwtBuf.PayloadBuf, []byte(parts[1]))
	if err != nil {
		return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
	}
	claimsBytes := jwtBuf.PayloadBuf[:n]

	if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
		return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
	}

	if jwt.Claims == nil {
		return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
	}

	sigLen := base64.RawURLEncoding.DecodedLen(len(parts[2]))
	if sigLen > cap(jwtBuf.SignatureBuf) {
		jwtBuf.SignatureBuf = make([]byte, sigLen)
	} else {
		jwtBuf.SignatureBuf = jwtBuf.SignatureBuf[:sigLen]
	}

	n, err = base64.RawURLEncoding.Decode(jwtBuf.SignatureBuf, []byte(parts[2]))
	if err != nil {
		return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
	}

	// Copy the signature so it does not alias the pooled buffer.
	jwt.Signature = make([]byte, n)
	copy(jwt.Signature, jwtBuf.SignatureBuf[:n])

	return jwt, nil
}
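
/*
Illustrative sketch of the segment decoding parseJWT performs, without the
pooled buffers: each JWT segment is unpadded base64url
(base64.RawURLEncoding), and the header and claims are JSON objects decoded
into map[string]interface{}. encodeSegment and decodeSegment are hypothetical
helpers, and the third segment is a placeholder rather than a real signature.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"
)

// encodeSegment marshals v to JSON and base64url-encodes it without padding.
func encodeSegment(v interface{}) string {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return base64.RawURLEncoding.EncodeToString(b)
}

// decodeSegment reverses encodeSegment into a generic map, the same shape
// parseJWT produces for Header and Claims.
func decodeSegment(seg string) (map[string]interface{}, error) {
	raw, err := base64.RawURLEncoding.DecodeString(seg)
	if err != nil {
		return nil, fmt.Errorf("decode: %w", err)
	}
	var m map[string]interface{}
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, fmt.Errorf("unmarshal: %w", err)
	}
	return m, nil
}

func main() {
	// The signature segment is a placeholder; nothing is verified here.
	token := encodeSegment(map[string]string{"alg": "RS256", "typ": "JWT"}) + "." +
		encodeSegment(map[string]interface{}{"sub": "alice", "exp": 1700000000}) + "." +
		base64.RawURLEncoding.EncodeToString([]byte("sig"))

	parts := strings.Split(token, ".")
	header, _ := decodeSegment(parts[0])
	claims, _ := decodeSegment(parts[1])

	exp, isFloat := claims["exp"].(float64)
	fmt.Println(len(parts), header["alg"], claims["sub"], exp, isFloat)
}
```

Because encoding/json decodes every JSON number into float64 when the target
is interface{}, Verify type-asserts exp, iat and nbf to float64.
*/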

// Verify validates the JWT header and claims according to the OIDC specification.
// It checks that the signing algorithm is supported and validates the issuer,
// audience, expiration, issued-at time, not-before time (if present), and
// subject. It also rejects replayed tokens by tracking jti claims. Verify does
// not check the signature; that is done separately by verifySignature.
// Parameters:
//   - issuerURL: Expected issuer URL to validate against
//   - clientID: Expected audience (client ID) to validate against
//   - skipReplayCheck: Optional; if the first value is true, replay protection is skipped
//
// Returns:
//   - An error describing the first validation failure encountered
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
	alg, ok := j.Header["alg"].(string)
	if !ok {
		return fmt.Errorf("missing 'alg' header")
	}
	supportedAlgs := map[string]bool{
		"RS256": true, "RS384": true, "RS512": true,
		"PS256": true, "PS384": true, "PS512": true,
		"ES256": true, "ES384": true, "ES512": true,
	}
	if !supportedAlgs[alg] {
		return fmt.Errorf("unsupported algorithm: %s", alg)
	}

	claims := j.Claims

	iss, ok := claims["iss"].(string)
	if !ok {
		return fmt.Errorf("missing 'iss' claim")
	}
	if err := verifyIssuer(iss, issuerURL); err != nil {
		return err
	}

	aud, ok := claims["aud"]
	if !ok {
		return fmt.Errorf("missing 'aud' claim")
	}
	if err := verifyAudience(aud, clientID); err != nil {
		return err
	}

	exp, ok := claims["exp"].(float64)
	if !ok {
		return fmt.Errorf("missing or invalid 'exp' claim")
	}
	if err := verifyExpiration(exp); err != nil {
		return err
	}

	iat, ok := claims["iat"].(float64)
	if !ok {
		return fmt.Errorf("missing or invalid 'iat' claim")
	}
	if err := verifyIssuedAt(iat); err != nil {
		return err
	}

	if nbf, ok := claims["nbf"].(float64); ok {
		if err := verifyNotBefore(nbf); err != nil {
			return err
		}
	}

	shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]

	jtiValue, jtiOk := claims["jti"].(string)

	if jtiOk && !shouldSkipReplay && jtiValue != "" {
		initReplayCache()

		// Look up and record the jti under a single write lock so that two
		// concurrent requests carrying the same token cannot both pass the check.
		// The entry is kept until exp plus ClockSkewToleranceFuture, because
		// verifyExpiration still accepts the token during that grace period.
		replayCacheMu.Lock()
		if replayCache != nil {
			if _, exists := replayCache.Get(jtiValue); exists {
				replayCacheMu.Unlock()
				return fmt.Errorf("token replay detected (jti: %s)", jtiValue)
			}
			expTime := time.Unix(int64(exp), 0).Add(ClockSkewToleranceFuture)
			if duration := time.Until(expTime); duration > 0 {
				replayCache.Set(jtiValue, true, duration)
			}
		}
		replayCacheMu.Unlock()
	}

	sub, ok := claims["sub"].(string)
	if !ok || sub == "" {
		return fmt.Errorf("missing or empty 'sub' claim")
	}

	return nil
}
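
/*
Illustrative sketch of a race-free replay check: the lookup and the insert
happen under one lock, so concurrent requests with the same jti cannot both be
accepted, and each entry lives until exp plus the clock skew allowance.
jtiStore, checkAndRecord and concurrentFirstUses are hypothetical; unlike the
package's cache, this map is unbounded and never evicts entries.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// jtiStore is a hypothetical in-memory store of jti expiry times.
type jtiStore struct {
	mu      sync.Mutex
	expires map[string]time.Time
}

func newJTIStore() *jtiStore {
	return &jtiStore{expires: make(map[string]time.Time)}
}

// checkAndRecord reports whether jti is already recorded and unexpired; if
// not, it records the jti until exp+skew (overwriting any expired entry).
func (s *jtiStore) checkAndRecord(jti string, exp time.Time, skew time.Duration, now time.Time) (replayed bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if until, ok := s.expires[jti]; ok && now.Before(until) {
		return true
	}
	s.expires[jti] = exp.Add(skew)
	return false
}

// concurrentFirstUses counts how many of n concurrent callers are accepted
// for the same jti.
func concurrentFirstUses(s *jtiStore, n int) int {
	var wg sync.WaitGroup
	var mu sync.Mutex
	accepted := 0
	exp := time.Now().Add(time.Hour)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if !s.checkAndRecord("token-123", exp, 2*time.Minute, time.Now()) {
				mu.Lock()
				accepted++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	return accepted
}

func main() {
	fmt.Println(concurrentFirstUses(newJTIStore(), 100)) // 1
}
```

Splitting the check into a read-locked Get followed by a separately locked Set
would let several goroutines see the jti as new before any of them records it.
*/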

// verifyAudience validates the JWT audience claim against the expected client ID.
// The audience claim can be either a single string or an array of strings.
// Parameters:
//   - tokenAudience: The audience claim from the JWT (string or []interface{})
//   - expectedAudience: The expected audience value (typically the OAuth client ID)
//
// Returns:
//   - An error if the claim type is invalid or the expected audience is not present
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
	switch aud := tokenAudience.(type) {
	case string:
		if aud != expectedAudience {
			return fmt.Errorf("invalid audience")
		}
	case []interface{}:
		found := false
		for _, v := range aud {
			if str, ok := v.(string); ok && str == expectedAudience {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("invalid audience")
		}
	default:
		return fmt.Errorf("invalid 'aud' claim type")
	}
	return nil
}
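
/*
Illustrative sketch of the two aud shapes verifyAudience accepts, starting
from raw JSON: a string decodes to string, and an array decodes to
[]interface{}, so each element needs its own type assertion. audienceContains
and audFromJSON are hypothetical helpers that mirror the switch above.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// audienceContains reports whether clientID appears in aud, which may be a
// string or a JSON array decoded as []interface{}.
func audienceContains(aud interface{}, clientID string) (bool, error) {
	switch v := aud.(type) {
	case string:
		return v == clientID, nil
	case []interface{}:
		for _, item := range v {
			if s, ok := item.(string); ok && s == clientID {
				return true, nil
			}
		}
		return false, nil
	default:
		return false, fmt.Errorf("invalid 'aud' claim type %T", aud)
	}
}

// audFromJSON decodes a claims payload and returns its aud value.
func audFromJSON(payload string) interface{} {
	var claims map[string]interface{}
	if err := json.Unmarshal([]byte(payload), &claims); err != nil {
		panic(err)
	}
	return claims["aud"]
}

func main() {
	for _, p := range []string{
		`{"aud":"my-client"}`,
		`{"aud":["other","my-client"]}`,
		`{"aud":["other"]}`,
		`{"aud":42}`,
	} {
		ok, err := audienceContains(audFromJSON(p), "my-client")
		fmt.Println(ok, err)
	}
}
```

A numeric aud such as 42 decodes to float64 and falls into the error branch.
*/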

// verifyIssuer validates the JWT issuer claim against the expected issuer URL.
// Parameters:
//   - tokenIssuer: The issuer claim from the JWT
//   - expectedIssuer: The expected issuer URL from the OIDC configuration
//
// Returns:
//   - An error if the issuers do not match
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
	if tokenIssuer != expectedIssuer {
		return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
	}
	return nil
}

// verifyTimeConstraint validates a time-based JWT claim with clock skew tolerance.
// With future=true (used for exp), the current time may be up to
// ClockSkewToleranceFuture past the claim time. With future=false (used for iat
// and nbf), the claim time may be up to ClockSkewTolerancePast in the future.
// Parameters:
//   - unixTime: The Unix timestamp from the JWT claim
//   - claimName: Name of the claim being validated (for error messages)
//   - future: If true, applies the expiration tolerance; if false, the issued-at/not-before tolerance
//
// Returns:
//   - An error describing the failure (e.g., "token has expired", "token used before issued"),
//     or nil if the claim is within the allowed window
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
	claimTime := time.Unix(int64(unixTime), 0)
	now := time.Now()

	var err error
	if future {
		allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
		if now.After(allowedExpiry) {
			err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
		}
	} else {
		allowedStart := claimTime.Add(-ClockSkewTolerancePast)
		if now.Before(allowedStart) {
			reason := "not yet valid"
			if claimName == "iat" {
				reason = "used before issued"
			}
			err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
		}
	}

	return err
}
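
/*
Worked example of the tolerance windows with the default values (2 minutes
for exp, 10 seconds for iat and nbf), using a fixed clock so the results are
reproducible. expOK, issuedOK and examples are hypothetical helpers that
restate the two comparisons in verifyTimeConstraint.

```go
package main

import (
	"fmt"
	"time"
)

const (
	futureSkew = 2 * time.Minute  // default of ClockSkewToleranceFuture
	pastSkew   = 10 * time.Second // default of ClockSkewTolerancePast
)

// expOK rejects only if now is after exp+futureSkew.
func expOK(exp, now time.Time) bool {
	return !now.After(exp.Add(futureSkew))
}

// issuedOK rejects only if now is before claim-pastSkew, that is, if the
// iat/nbf claim is more than pastSkew ahead of now.
func issuedOK(claim, now time.Time) bool {
	return !now.Before(claim.Add(-pastSkew))
}

// examples evaluates four claims against a fixed "now".
func examples() [4]bool {
	now := time.Unix(1700000000, 0)
	return [4]bool{
		expOK(now.Add(-90*time.Second), now),   // exp 90s ago, within 2m grace
		expOK(now.Add(-3*time.Minute), now),    // exp 3m ago, beyond grace
		issuedOK(now.Add(5*time.Second), now),  // iat 5s ahead, within 10s
		issuedOK(now.Add(30*time.Second), now), // iat 30s ahead, beyond 10s
	}
}

func main() {
	fmt.Println(examples()) // [true false true false]
}
```
*/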

// verifyExpiration validates the JWT expiration time (exp claim) with clock skew tolerance.
// It calls verifyTimeConstraint with future=true.
func verifyExpiration(expiration float64) error {
	return verifyTimeConstraint(expiration, "exp", true)
}

// verifyIssuedAt validates the JWT issued-at time (iat claim) with clock skew tolerance.
// It calls verifyTimeConstraint with future=false.
func verifyIssuedAt(issuedAt float64) error {
	return verifyTimeConstraint(issuedAt, "iat", false)
}

// verifyNotBefore validates the JWT not-before time (nbf claim) with clock skew tolerance.
// It calls verifyTimeConstraint with future=false.
func verifyNotBefore(notBefore float64) error {
	return verifyTimeConstraint(notBefore, "nbf", false)
}

// verifySignature verifies the JWT signature using the provided public key.
// It supports RSA (RS256/384/512 with PKCS #1 v1.5, PS256/384/512 with PSS) and
// ECDSA (ES256/384/512), and rejects keys whose type does not match alg.
// Parameters:
//   - tokenString: The complete JWT token string
//   - publicKeyPEM: The public key in PEM format (PKIX "PUBLIC KEY" block)
//   - alg: The signing algorithm specified in the JWT header
//
// Returns:
//   - An error if key parsing fails, the algorithm is unsupported,
//     or the signature verification fails
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return fmt.Errorf("invalid token format")
	}
	signedContent := parts[0] + "." + parts[1]
	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return fmt.Errorf("failed to decode signature: %w", err)
	}
	block, _ := pem.Decode(publicKeyPEM)
	if block == nil {
		return fmt.Errorf("failed to parse PEM block containing the public key")
	}
	pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return fmt.Errorf("failed to parse public key: %w", err)
	}
	var hashFunc crypto.Hash
	switch alg {
	case "RS256", "PS256", "ES256":
		hashFunc = crypto.SHA256
	case "RS384", "PS384", "ES384":
		hashFunc = crypto.SHA384
	case "RS512", "PS512", "ES512":
		hashFunc = crypto.SHA512
	default:
		return fmt.Errorf("unsupported algorithm: %s", alg)
	}
	h := hashFunc.New()
	h.Write([]byte(signedContent))
	hashed := h.Sum(nil)
	switch pubKey := pubKey.(type) {
	case *rsa.PublicKey:
		switch {
		case strings.HasPrefix(alg, "RS"):
			return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
		case strings.HasPrefix(alg, "PS"):
			// nil options: the PSS salt length is auto-detected during verification.
			return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
		default:
			return fmt.Errorf("unexpected key type for algorithm %s", alg)
		}
	case *ecdsa.PublicKey:
		if !strings.HasPrefix(alg, "ES") {
			return fmt.Errorf("unexpected key type for algorithm %s", alg)
		}
		// A JWS ECDSA signature is the fixed-width concatenation R || S
		// (RFC 7518, section 3.4); each half is as long as the curve size in bytes.
		keySize := (pubKey.Curve.Params().BitSize + 7) / 8
		if len(signature) != 2*keySize {
			return fmt.Errorf("invalid ECDSA signature length: got %d, want %d", len(signature), 2*keySize)
		}
		var r, s big.Int
		r.SetBytes(signature[:keySize])
		s.SetBytes(signature[keySize:])
		if !ecdsa.Verify(pubKey, hashed, &r, &s) {
			return fmt.Errorf("invalid ECDSA signature")
		}
		return nil
	default:
		return fmt.Errorf("unsupported public key type: %T", pubKey)
	}
}
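
/*
Illustrative sketch of the RSA branch of verifySignature: a PKIX public key is
PEM-encoded, parsed back with x509.ParsePKIXPublicKey, and used to check RS256
(PKCS #1 v1.5) and PS256 (PSS) signatures over a signing input. verifyRSA and
roundTrip are hypothetical helpers, and the key is generated on the fly rather
than obtained from an identity provider.

```go
package main

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/pem"
	"fmt"
)

// verifyRSA parses a PEM-encoded PKIX public key and checks an RS256 or PS256
// signature over signingInput.
func verifyRSA(pemBytes []byte, signingInput string, sig []byte, pss bool) error {
	block, _ := pem.Decode(pemBytes)
	if block == nil {
		return fmt.Errorf("no PEM block")
	}
	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return err
	}
	pub, ok := key.(*rsa.PublicKey)
	if !ok {
		return fmt.Errorf("not an RSA key: %T", key)
	}
	digest := sha256.Sum256([]byte(signingInput))
	if pss {
		return rsa.VerifyPSS(pub, crypto.SHA256, digest[:], sig, nil)
	}
	return rsa.VerifyPKCS1v15(pub, crypto.SHA256, digest[:], sig)
}

// roundTrip signs with a fresh 2048-bit key, then returns the verification
// results for the original input and for a tampered input.
func roundTrip(pss bool) (good, tampered error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	der, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
	if err != nil {
		panic(err)
	}
	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der})

	input := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhbGljZSJ9" // example header.payload
	digest := sha256.Sum256([]byte(input))
	var sig []byte
	if pss {
		sig, err = rsa.SignPSS(rand.Reader, priv, crypto.SHA256, digest[:],
			&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
	} else {
		sig, err = rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, digest[:])
	}
	if err != nil {
		panic(err)
	}
	return verifyRSA(pemBytes, input, sig, pss), verifyRSA(pemBytes, input+"x", sig, pss)
}

func main() {
	for _, pss := range []bool{false, true} {
		good, tampered := roundTrip(pss)
		fmt.Println(good == nil, tampered != nil) // true true
	}
}
```

Signing with PSSSaltLengthEqualsHash and verifying with nil options works
because verification auto-detects the salt length.
*/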
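
/*
Illustrative sketch of the ES256 case: ecdsa.Sign returns r and s as big
integers, JWS stores them as a fixed-width R || S byte string (32 bytes each
for P-256), and the verifier splits the string back at the curve size.
toJWS, verifyJWS and es256Demo are hypothetical helpers; verifyJWS rejects any
signature whose length is not exactly twice the curve size in bytes.

```go
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"math/big"
)

// toJWS encodes r and s as fixed-width big-endian halves; FillBytes
// left-pads each value with zeros.
func toJWS(r, s *big.Int, size int) []byte {
	sig := make([]byte, 2*size)
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig
}

// verifyJWS splits a raw R || S signature and verifies it.
func verifyJWS(pub *ecdsa.PublicKey, digest, sig []byte) error {
	size := (pub.Curve.Params().BitSize + 7) / 8
	if len(sig) != 2*size {
		return fmt.Errorf("invalid ECDSA signature length: got %d, want %d", len(sig), 2*size)
	}
	r := new(big.Int).SetBytes(sig[:size])
	s := new(big.Int).SetBytes(sig[size:])
	if !ecdsa.Verify(pub, digest, r, s) {
		return fmt.Errorf("invalid ECDSA signature")
	}
	return nil
}

// es256Demo returns the results of verifying a valid signature, a tampered
// digest, and a truncated signature with a fresh P-256 key.
func es256Demo() (good, tampered, truncated error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	digest := sha256.Sum256([]byte("header.payload"))
	r, s, err := ecdsa.Sign(rand.Reader, priv, digest[:])
	if err != nil {
		panic(err)
	}
	sig := toJWS(r, s, 32) // P-256: 32 bytes each for r and s
	other := sha256.Sum256([]byte("header.payload2"))
	return verifyJWS(&priv.PublicKey, digest[:], sig),
		verifyJWS(&priv.PublicKey, other[:], sig),
		verifyJWS(&priv.PublicKey, digest[:], sig[:63])
}

func main() {
	good, tampered, truncated := es256Demo()
	fmt.Println(good == nil, tampered != nil, truncated != nil) // true true true
}
```
*/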