mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
417 lines
12 KiB
Go
417 lines
12 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"math/big"
|
|
"strings"
|
|
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"fmt"
|
|
|
|
"time"
|
|
)
|
|
|
|
// JWT is the decoded form of a JSON Web Token (RFC 7519). It keeps the
// header and claims as generic JSON objects, the raw signature bytes, and
// the compact token string the value was built from.
type JWT struct {
	// Header carries the token metadata (algorithm, key ID, ...);
	// Claims carries the payload (subject, expiration, ...).
	Header, Claims map[string]interface{}

	// Signature is the signature in raw (already decoded) bytes.
	Signature []byte

	// Token is the original, unmodified JWT string.
	Token string
}
|
|
|
|
// parseJWT parses a JWT token string into a JWT struct.
|
|
// It validates the token format and decodes the three parts
|
|
// (header, claims, signature) using base64url decoding.
|
|
// Parameters:
|
|
// - tokenString: The raw JWT token string
|
|
//
|
|
// Returns:
|
|
// - A parsed JWT struct
|
|
// - An error if the token format is invalid or parsing fails
|
|
func parseJWT(tokenString string) (*JWT, error) {
|
|
parts := strings.Split(tokenString, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
|
}
|
|
|
|
jwt := &JWT{
|
|
Token: tokenString,
|
|
}
|
|
|
|
// Decode and unmarshal the header
|
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
|
}
|
|
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
// Decode and unmarshal the claims
|
|
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
|
}
|
|
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
|
}
|
|
|
|
// Decode the signature
|
|
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
|
}
|
|
jwt.Signature = signatureBytes
|
|
|
|
return jwt, nil
|
|
}
|
|
|
|
// Verify validates the standard JWT claims as defined in RFC 7519.
|
|
// It checks:
|
|
// - issuer (iss) matches the expected issuer URL
|
|
// - audience (aud) includes the client ID
|
|
// - expiration time (exp) is in the future (with clock skew tolerance)
|
|
// - issued at time (iat) is in the past (with clock skew tolerance)
|
|
// - not before time (nbf) is in the past (with clock skew tolerance)
|
|
// - subject (sub) is present and not empty
|
|
// - algorithm matches expected value to prevent algorithm switching attacks
|
|
//
|
|
// Returns an error if any validation fails.
|
|
func (j *JWT) Verify(issuerURL, clientID string) error {
|
|
// Debug logging of validation parameters
|
|
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
|
|
// Debug logging of token header
|
|
fmt.Printf("Token header: %+v\n", j.Header)
|
|
|
|
// Validate algorithm to prevent algorithm switching attacks
|
|
alg, ok := j.Header["alg"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing 'alg' header")
|
|
}
|
|
// List of supported algorithms - should match those in verifySignature
|
|
supportedAlgs := map[string]bool{
|
|
"RS256": true, "RS384": true, "RS512": true,
|
|
"PS256": true, "PS384": true, "PS512": true,
|
|
"ES256": true, "ES384": true, "ES512": true,
|
|
}
|
|
if !supportedAlgs[alg] {
|
|
return fmt.Errorf("unsupported algorithm: %s", alg)
|
|
}
|
|
|
|
claims := j.Claims
|
|
|
|
// Debug logging of all claims
|
|
fmt.Printf("Token claims: %+v\n", claims)
|
|
|
|
iss, ok := claims["iss"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing 'iss' claim")
|
|
}
|
|
if err := verifyIssuer(iss, issuerURL); err != nil {
|
|
return err
|
|
}
|
|
|
|
aud, ok := claims["aud"]
|
|
if !ok {
|
|
return fmt.Errorf("missing 'aud' claim")
|
|
}
|
|
if err := verifyAudience(aud, clientID); err != nil {
|
|
return err
|
|
}
|
|
|
|
exp, ok := claims["exp"].(float64)
|
|
if !ok {
|
|
return fmt.Errorf("missing or invalid 'exp' claim")
|
|
}
|
|
if err := verifyExpiration(exp); err != nil {
|
|
return err
|
|
}
|
|
|
|
iat, ok := claims["iat"].(float64)
|
|
if !ok {
|
|
return fmt.Errorf("missing or invalid 'iat' claim")
|
|
}
|
|
if err := verifyIssuedAt(iat); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Validate nbf (not before) claim if present
|
|
if nbf, ok := claims["nbf"].(float64); ok {
|
|
if err := verifyNotBefore(nbf); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Validate jti (JWT ID) claim if present
|
|
if jti, ok := claims["jti"].(string); ok {
|
|
// Could add replay detection here if needed
|
|
_ = jti
|
|
}
|
|
|
|
sub, ok := claims["sub"].(string)
|
|
if !ok || sub == "" {
|
|
return fmt.Errorf("missing or empty 'sub' claim")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array (as produced by
// JSON decoding, i.e. []interface{}); for arrays, a match on any one string
// element is sufficient. Non-string array elements are ignored.
//
// Nothing is written to stdout: this runs for every validated token and
// the claim value is caller data.
//
// Parameters:
//   - tokenAudience: The audience claim from the token
//   - expectedAudience: The expected audience value
//
// Returns an error if the claim has an unsupported type or does not
// contain expectedAudience.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
	switch aud := tokenAudience.(type) {
	case string:
		if aud == expectedAudience {
			return nil
		}
	case []interface{}:
		for _, v := range aud {
			if s, ok := v.(string); ok && s == expectedAudience {
				return nil
			}
		}
	default:
		return fmt.Errorf("invalid 'aud' claim type")
	}
	// Recognised shape, but the expected audience is not in it.
	return fmt.Errorf("invalid audience")
}
|
|
|
|
// verifyIssuer validates the token's issuer claim.
// The comparison is an exact, case-sensitive string match; no URL
// normalisation (trailing slash, scheme case, ...) is applied.
//
// Nothing is written to stdout; the mismatch details are carried by the
// returned error only.
//
// Parameters:
//   - tokenIssuer: The issuer claim from the token
//   - expectedIssuer: The expected issuer URL
//
// Returns an error if the two values differ.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
	if tokenIssuer == expectedIssuer {
		return nil
	}
	return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
		tokenIssuer, expectedIssuer)
}
|
|
|
|
// clockSkewTolerance is the allowance used by the time-based claim checks
// (exp, iat, nbf) when comparing token timestamps against the local clock.
const clockSkewTolerance time.Duration = 2 * time.Minute
|
|
|
|
// verifyExpiration checks if the token's expiration time has passed.
|
|
// The expiration time is compared against the current time with clock skew tolerance.
|
|
// Parameters:
|
|
// - expiration: The expiration timestamp from the token
|
|
//
|
|
// Returns an error if the token has expired.
|
|
func verifyExpiration(expiration float64) error {
|
|
expirationTime := time.Unix(int64(expiration), 0)
|
|
// Truncate current time to seconds for consistent comparison
|
|
now := time.Now().Truncate(time.Second)
|
|
skewedNow := now.Add(clockSkewTolerance)
|
|
|
|
// Debug logging
|
|
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
|
expirationTime.UTC(),
|
|
now.UTC(),
|
|
skewedNow.UTC(),
|
|
clockSkewTolerance)
|
|
|
|
// Allow tokens that expire exactly now
|
|
if expirationTime.Equal(now) {
|
|
return nil
|
|
}
|
|
|
|
if skewedNow.After(expirationTime) {
|
|
return fmt.Errorf("token has expired (exp: %v, now: %v)",
|
|
expirationTime.UTC(), now.UTC())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// verifyIssuedAt validates the token's issued-at time.
|
|
// Ensures the token wasn't issued in the future, accounting for clock skew.
|
|
// Parameters:
|
|
// - issuedAt: The issued-at timestamp from the token
|
|
//
|
|
// Returns an error if the token was issued in the future.
|
|
func verifyIssuedAt(issuedAt float64) error {
|
|
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
|
// Truncate current time to seconds for consistent comparison
|
|
now := time.Now().Truncate(time.Second)
|
|
skewedNow := now.Add(-clockSkewTolerance)
|
|
|
|
// Debug logging
|
|
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
|
issuedAtTime.UTC(),
|
|
now.UTC(),
|
|
skewedNow.UTC(),
|
|
clockSkewTolerance)
|
|
|
|
// Allow tokens issued in the same second as current time
|
|
if issuedAtTime.Equal(now) {
|
|
return nil
|
|
}
|
|
|
|
if skewedNow.Before(issuedAtTime) {
|
|
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
|
issuedAtTime.UTC(), now.UTC())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// verifyNotBefore validates the token's not-before time if present.
|
|
// Ensures the token is not used before its valid time period, accounting for clock skew.
|
|
// Parameters:
|
|
// - notBefore: The not-before timestamp from the token
|
|
//
|
|
// Returns an error if the token is not yet valid.
|
|
func verifyNotBefore(notBefore float64) error {
|
|
notBeforeTime := time.Unix(int64(notBefore), 0)
|
|
// Truncate current time to seconds for consistent comparison
|
|
now := time.Now().Truncate(time.Second)
|
|
skewedNow := now.Add(-clockSkewTolerance)
|
|
|
|
// Debug logging
|
|
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
|
notBeforeTime.UTC(),
|
|
now.UTC(),
|
|
skewedNow.UTC(),
|
|
clockSkewTolerance)
|
|
|
|
// Allow tokens that become valid exactly now
|
|
if notBeforeTime.Equal(now) {
|
|
return nil
|
|
}
|
|
|
|
if skewedNow.Before(notBeforeTime) {
|
|
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
|
|
notBeforeTime.UTC(), now.UTC())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
//   - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
//   - RSA-PSS: PS256, PS384, PS512
//   - ECDSA: ES256, ES384, ES512
//
// The signed content is the first two token segments joined by ".",
// hashed with SHA-256/384/512 according to the algorithm suffix. The
// algorithm family must agree with the type of the supplied key.
//
// For ECDSA the signature must be the fixed-width R||S concatenation
// described in RFC 7518 §3.4: exactly twice the curve's coordinate size in
// bytes. Any other length is rejected before verification so that padded
// or truncated encodings of a signature are not accepted.
//
// Parameters:
//   - tokenString: The complete JWT token string
//   - publicKeyPEM: The PEM-encoded PKIX public key for verification
//   - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails. RSA failures are
// returned as produced by crypto/rsa.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
	// Split the token into its three parts.
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return fmt.Errorf("invalid token format")
	}
	signedContent := parts[0] + "." + parts[1]

	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return fmt.Errorf("failed to decode signature: %w", err)
	}

	// Only the first PEM block is considered; trailing data is ignored.
	block, _ := pem.Decode(publicKeyPEM)
	if block == nil {
		return fmt.Errorf("failed to parse PEM block containing the public key")
	}
	pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return fmt.Errorf("failed to parse public key: %w", err)
	}

	// The numeric suffix of the algorithm selects the digest.
	var hashFunc crypto.Hash
	switch alg {
	case "RS256", "PS256", "ES256":
		hashFunc = crypto.SHA256
	case "RS384", "PS384", "ES384":
		hashFunc = crypto.SHA384
	case "RS512", "PS512", "ES512":
		hashFunc = crypto.SHA512
	default:
		return fmt.Errorf("unsupported algorithm: %s", alg)
	}

	h := hashFunc.New()
	h.Write([]byte(signedContent)) // hash.Hash.Write never returns an error
	hashed := h.Sum(nil)

	// Dispatch on the key type, then require the matching algorithm family.
	switch key := pubKey.(type) {
	case *rsa.PublicKey:
		switch {
		case strings.HasPrefix(alg, "RS"):
			return rsa.VerifyPKCS1v15(key, hashFunc, hashed, signature)
		case strings.HasPrefix(alg, "PS"):
			// nil options: salt length is auto-detected during verification.
			return rsa.VerifyPSS(key, hashFunc, hashed, signature, nil)
		}
		return fmt.Errorf("unexpected key type for algorithm %s", alg)

	case *ecdsa.PublicKey:
		if !strings.HasPrefix(alg, "ES") {
			return fmt.Errorf("unexpected key type for algorithm %s", alg)
		}
		// Coordinate size in bytes: 32 for P-256, 48 for P-384, 66 for P-521.
		size := (key.Curve.Params().BitSize + 7) / 8
		if len(signature) != 2*size {
			return fmt.Errorf("invalid ECDSA signature length")
		}
		r := new(big.Int).SetBytes(signature[:size])
		s := new(big.Int).SetBytes(signature[size:])
		if !ecdsa.Verify(key, hashed, r, s) {
			return fmt.Errorf("invalid ECDSA signature")
		}
		return nil

	default:
		return fmt.Errorf("unsupported public key type: %T", pubKey)
	}
}
|