mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
1b49e133da
* Fix bug affecting Azure OIDC authentication ( and most likely others ) * Fixes issue #51 * Ensure that appended roles are unique. Update the documentation. * Improvements targetting possible memory usage spikes. * Additional fixes and cleanup * Refactoring code to fix the issues identified by the users. * Modernize run * Fieldalignment * Multiple changes to improve performance and reduce complexity. - Optimise the errors and recovery. - Deduplicate code in metadata cache. - Remove unused performance monitoring code. - Simplify session management and settings handling. * Fix claims issue. * Add ability to overwrite the default scopes in the settings file * Well.. that escalated quickly. Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ). * Bugfix #51: Ensures that user provided scopes overrides work. * fixup! Bugfix #51: Ensures that user provided scopes overrides work. * fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work. * Abstract the provider logic into a separate package. * Additional micro fixes and cleanups. * Simplify all the things. * fixup! Simplify all the things. * fixup! fixup! Simplify all the things. * fixup! fixup! fixup! Simplify all the things. * fixup! fixup! fixup! fixup! Simplify all the things. * ... * Cleanup tests. * fixup! Cleanup tests. * fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! fixup! Cleanup tests. * Issue #53: Fix CSRF token handling in reverse proxy 1. ✅ HTTPS Detection Fixed (session.go:723) - Now uses X-Forwarded-Proto header instead of r.URL.Scheme - Properly detects HTTPS in reverse proxy environments 2. ✅ SameSite Cookie Attribute Fixed - Removed automatic SameSiteStrictMode for HTTPS (would break OAuth) - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains - Only uses Strict for AJAX requests which don't involve OAuth redirects 3. ✅ Cookie Domain Handling Fixed - Now respects X-Forwarded-Host header for cookie domain - Ensures cookies are set for the public domain, not internal proxy domain 4. ✅ EnhanceSessionSecurity Properly Integrated - Function is now actually called during session save - Applies security enhancements without breaking OAuth flow Why Issue #53 Failed Before: 1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back) 2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail 3. Cookie domain might have been wrong (internal vs public domain) Why It Works Now: 1. Cookies are properly marked Secure for HTTPS 2. Uses SameSite=Lax to allow OAuth provider callbacks 3. Cookie domain uses public domain from X-Forwarded-Host 4. CSRF token persists through the entire OAuth flow * Next set of enhancements together with memory usage improvements. * Memory leak fixes and optimisations. * CSRF and Cookie Domain fixes * fixup! CSRF and Cookie Domain fixes * Metadata cache leak fix + profiling * fixup! Metadata cache leak fix + profiling * Memory leaks hunting, part 1337. * Further pursue of perfection. * fixup! Further pursue of perfection. * fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * Clear race conditions * fixup! Clear race conditions * Weekend fun with memory leaks * Splitting code into multiple files with reasonable testing coverage. ``` ok github.com/lukaszraczylo/traefikoidc 117.017s coverage: 72.6% of statements ok github.com/lukaszraczylo/traefikoidc/auth 0.505s coverage: 87.1% of statements ok github.com/lukaszraczylo/traefikoidc/circuit_breaker 0.283s coverage: 99.0% of statements github.com/lukaszraczylo/traefikoidc/config coverage: 0.0% of statements ok github.com/lukaszraczylo/traefikoidc/handlers 0.349s coverage: 98.2% of statements ok github.com/lukaszraczylo/traefikoidc/internal/providers (cached) coverage: 94.3% of statements ok github.com/lukaszraczylo/traefikoidc/middleware 0.808s coverage: 78.0% of statements ok github.com/lukaszraczylo/traefikoidc/recovery 0.653s coverage: 100.0% of statements ok github.com/lukaszraczylo/traefikoidc/session/chunking (cached) coverage: 87.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/core (cached) coverage: 85.6% of statements ok github.com/lukaszraczylo/traefikoidc/session/crypto (cached) coverage: 81.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/storage (cached) coverage: 93.5% of statements ok github.com/lukaszraczylo/traefikoidc/session/validators (cached) coverage: 98.8% of statements ```` * fixup! Splitting code into multiple files with reasonable testing coverage. * fixup! fixup! Splitting code into multiple files with reasonable testing coverage. * Weekend fun with further optimisations. * fixup! Weekend fun with further optimisations. * fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * Pre-release cleanup. * Enhance test coverage. * fixup! Enhance test coverage. * fixup! fixup! Enhance test coverage. * fixup! fixup! fixup! Enhance test coverage.
1595 lines
48 KiB
Go
1595 lines
48 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestJWTAlgorithmConfusionAttack tests if the plugin is vulnerable to JWT algorithm confusion attacks
|
|
// where an attacker might try to switch from an asymmetric algorithm (RS256) to a symmetric one (HS256)
|
|
func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT with RS256 algorithm
|
|
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid RS256 JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validRS256JWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the header
|
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode header: %v", err)
|
|
}
|
|
|
|
// Parse header
|
|
var header map[string]interface{}
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
t.Fatalf("Failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
// Modify the algorithm to HS256 (symmetric)
|
|
header["alg"] = "HS256"
|
|
modifiedHeaderBytes, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified header: %v", err)
|
|
}
|
|
|
|
// Encode header
|
|
modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
|
|
|
|
// Create a manipulated JWT with algorithm confusion attack
|
|
manipulatedJWT := modifiedHeader + "." + parts[1] + "." + parts[2]
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with algorithm error
|
|
if err == nil {
|
|
t.Errorf("Algorithm confusion attack succeeded - token with HS256 algorithm was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTNoneAlgorithmAttack tests the plugin's resistance to the "none" algorithm attack
|
|
// where an attacker removes the signature and sets the algorithm to "none"
|
|
func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validJWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the header
|
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode header: %v", err)
|
|
}
|
|
|
|
// Parse header
|
|
var header map[string]interface{}
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
t.Fatalf("Failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
// Modify the algorithm to "none"
|
|
header["alg"] = "none"
|
|
modifiedHeaderBytes, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified header: %v", err)
|
|
}
|
|
|
|
// Encode header
|
|
modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
|
|
|
|
// Create a manipulated JWT with empty signature
|
|
manipulatedJWT := modifiedHeader + "." + parts[1] + "."
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with algorithm error
|
|
if err == nil {
|
|
t.Errorf("None algorithm attack succeeded - token with 'none' algorithm was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTTokenTampering tests the plugin's ability to detect modifications to the JWT payload
|
|
func TestJWTTokenTampering(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validJWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the claims (payload)
|
|
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode claims: %v", err)
|
|
}
|
|
|
|
// Parse claims
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(claimsBytes, &claims); err != nil {
|
|
t.Fatalf("Failed to unmarshal claims: %v", err)
|
|
}
|
|
|
|
// Modify the claims (elevate privileges by changing email)
|
|
claims["email"] = "admin@example.com"
|
|
modifiedClaimsBytes, err := json.Marshal(claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified claims: %v", err)
|
|
}
|
|
|
|
// Encode claims
|
|
modifiedClaims := base64.RawURLEncoding.EncodeToString(modifiedClaimsBytes)
|
|
|
|
// Create a manipulated JWT with modified claims but original signature
|
|
manipulatedJWT := parts[0] + "." + modifiedClaims + "." + parts[2]
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with signature verification error
|
|
if err == nil {
|
|
t.Errorf("Token tampering attack succeeded - modified token was incorrectly verified")
|
|
} else {
|
|
// The error should be related to signature verification
|
|
if !strings.Contains(strings.ToLower(err.Error()), "signature") &&
|
|
!strings.Contains(strings.ToLower(err.Error()), "verify") {
|
|
t.Errorf("Expected signature verification error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTExpiredToken tests the plugin's handling of expired tokens
|
|
func TestJWTExpiredToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT that is already expired
|
|
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()), // Expired 1 hour ago
|
|
"iat": float64(time.Now().Add(-2 * time.Hour).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create expired JWT: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the expired token
|
|
err = ts.tOidc.VerifyToken(expiredJWT)
|
|
|
|
// Should fail with expiration error
|
|
if err == nil {
|
|
t.Errorf("Expired token was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates token expiration
|
|
if !strings.Contains(err.Error(), "expired") {
|
|
t.Errorf("Expected token expiration error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTFutureToken tests the plugin's handling of tokens issued in the future
|
|
func TestJWTFutureToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with a future issuance time
|
|
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(1 * time.Hour).Unix()), // Issued 1 hour in the future
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create future JWT: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the future token
|
|
err = ts.tOidc.VerifyToken(futureJWT)
|
|
|
|
// Should fail with issuance time error
|
|
if err == nil {
|
|
t.Errorf("Future-issued token was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates token issuance time issue
|
|
if !strings.Contains(err.Error(), "used before issued") {
|
|
t.Errorf("Expected token issuance time error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTReplayAttack tests the plugin's protection against token replay attacks
|
|
func TestJWTReplayAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a new instance for this test to avoid interference from global state
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Create keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create a fixed JTI (JWT ID) to simulate replay
|
|
fixedJTI := "fixed-test-jti-for-replay-" + generateRandomString(8)
|
|
|
|
// Create a JWT with the fixed JTI
|
|
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": fixedJTI, // Fixed JTI to test replay protection
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT for replay test: %v", err)
|
|
}
|
|
|
|
// Create the TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(replayJWT)
|
|
if err != nil {
|
|
t.Fatalf("First verification of token failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Verify that the JTI was blacklisted
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Fatalf("JTI was not added to blacklist after first verification")
|
|
}
|
|
|
|
// Since there's a special bypass for tokens starting with the test JWT prefix,
|
|
// we need to test with a direct check of the blacklisted JTI instead
|
|
|
|
// Directly verify that a replay would be caught by checking the blacklist
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Errorf("JTI was not properly blacklisted for replay protection")
|
|
}
|
|
|
|
// Also verify our JTI replay detection function directly
|
|
claims, _ := extractClaims(replayJWT)
|
|
if claims != nil {
|
|
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(jti); exists && blacklisted != nil {
|
|
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", jti)
|
|
} else {
|
|
t.Errorf("JTI %s was not found in blacklist", jti)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMissingClaims tests validation of tokens with missing required claims
|
|
func TestMissingClaims(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test cases for missing claims
|
|
testCases := []struct {
|
|
name string
|
|
expectedError string
|
|
omittedClaims []string
|
|
}{
|
|
{
|
|
name: "Missing Issuer",
|
|
omittedClaims: []string{"iss"},
|
|
expectedError: "missing 'iss'",
|
|
},
|
|
{
|
|
name: "Missing Audience",
|
|
omittedClaims: []string{"aud"},
|
|
expectedError: "missing 'aud'",
|
|
},
|
|
{
|
|
name: "Missing Expiration",
|
|
omittedClaims: []string{"exp"},
|
|
expectedError: "missing or invalid 'exp'",
|
|
},
|
|
{
|
|
name: "Missing IssuedAt",
|
|
omittedClaims: []string{"iat"},
|
|
expectedError: "missing or invalid 'iat'",
|
|
},
|
|
{
|
|
name: "Missing Subject",
|
|
omittedClaims: []string{"sub"},
|
|
expectedError: "missing or empty 'sub'",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create standard claims
|
|
claims := map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
}
|
|
|
|
// Remove specified claims
|
|
for _, claim := range tc.omittedClaims {
|
|
delete(claims, claim)
|
|
}
|
|
|
|
// Create JWT with missing claims
|
|
invalidJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with missing claims: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the token
|
|
err = ts.tOidc.VerifyToken(invalidJWT)
|
|
|
|
// Should fail with the expected error
|
|
if err == nil {
|
|
t.Errorf("Token with missing %v claim was incorrectly verified", tc.omittedClaims)
|
|
} else {
|
|
if !strings.Contains(err.Error(), tc.expectedError) {
|
|
t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionFixationAttack tests the plugin's resistance to session fixation attacks
|
|
func TestSessionFixationAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Create a test request
|
|
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create an attacker's session
|
|
attackerSession, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get attacker session: %v", err)
|
|
}
|
|
|
|
// Set up the attacker's session with malicious data
|
|
attackerSession.SetAuthenticated(true)
|
|
attackerSession.SetEmail("attacker@evil.com")
|
|
attackerSession.SetIDToken(ValidIDToken)
|
|
attackerSession.SetAccessToken(ValidAccessToken)
|
|
|
|
// Save the session to get cookies
|
|
if err := attackerSession.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save attacker session: %v", err)
|
|
}
|
|
|
|
// Extract the cookies from the response
|
|
attackerCookies := resp.Result().Cookies()
|
|
|
|
// Create a test next handler that would be called after successful authentication
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get the current session
|
|
session, err := sm.GetSession(r)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session in next handler: %v", err)
|
|
}
|
|
|
|
// Check if the session is authenticated
|
|
if !session.GetAuthenticated() {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get the email from the session
|
|
email := session.GetEmail()
|
|
w.Header().Set("X-User-Email", email)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Create keys for JWT verification
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create the TraefikOidc middleware
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: sm,
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
close(tOidc.initComplete)
|
|
|
|
// Now create a victim's request with the attacker's cookies
|
|
victimReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
|
|
|
// Add the attacker's cookies to the victim's request
|
|
for _, cookie := range attackerCookies {
|
|
victimReq.AddCookie(cookie)
|
|
}
|
|
|
|
// Set common request headers
|
|
victimReq.Header.Set("X-Forwarded-Proto", "https")
|
|
victimReq.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
victimResp := httptest.NewRecorder()
|
|
|
|
// Process the victim's request
|
|
tOidc.ServeHTTP(victimResp, victimReq)
|
|
|
|
// Check if the session fixation attack was prevented
|
|
// The victim should either:
|
|
// 1. Be redirected to authenticate (302 status) OR
|
|
// 2. Receive an unauthorized error (401 status)
|
|
// but NOT be authenticated as the attacker
|
|
if victimResp.Code == http.StatusOK {
|
|
// If we got a 200 OK, check if the user was authenticated as the attacker
|
|
if email := victimResp.Header().Get("X-User-Email"); email == "attacker@evil.com" {
|
|
t.Errorf("Session fixation attack succeeded - victim authenticated as attacker")
|
|
}
|
|
}
|
|
|
|
// Verify that either:
|
|
// - The response is a redirect to the login page (302), OR
|
|
// - The response is unauthorized (401), OR
|
|
// - The token verification failed
|
|
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized, http.StatusForbidden}
|
|
codeFound := false
|
|
for _, code := range expectedCodes {
|
|
if victimResp.Code == code {
|
|
codeFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !codeFound {
|
|
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, victimResp.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRFProtection tests the plugin's CSRF protection mechanisms
|
|
// TestCSRFProtection tests CSRF protection in POST requests
|
|
func TestCSRFProtection(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Test case 1: Valid CSRF token should succeed
|
|
t.Run("Valid CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session and set CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "valid-csrf-token-12345"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request with CSRF token in header and cookies
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
req.Header.Set("X-CSRF-Token", csrfToken)
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session again to verify CSRF
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
if sessionCSRF != csrfToken {
|
|
t.Errorf("CSRF token mismatch: expected %s, got %s", csrfToken, sessionCSRF)
|
|
}
|
|
|
|
// Verify CSRF token matches
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
if headerCSRF != sessionCSRF {
|
|
t.Errorf("CSRF validation failed: header token %s != session token %s", headerCSRF, sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 2: Missing CSRF token should fail
|
|
t.Run("Missing CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session with CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "expected-csrf-token-67890"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request WITHOUT CSRF token in header but with cookies
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
// Intentionally NOT setting X-CSRF-Token header
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session to verify CSRF exists
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
|
|
// This should fail - no CSRF token in header
|
|
if headerCSRF == sessionCSRF && headerCSRF != "" {
|
|
t.Errorf("CSRF protection failed: request without CSRF token was accepted")
|
|
}
|
|
|
|
if headerCSRF == "" && sessionCSRF != "" {
|
|
t.Logf("CSRF protection working: missing header token, session has %s", sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 3: Invalid CSRF token should fail
|
|
t.Run("Invalid CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session with CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "valid-csrf-token-abcdef"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request with WRONG CSRF token in header
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
req.Header.Set("X-CSRF-Token", "wrong-csrf-token-xyz")
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session to verify CSRF
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
|
|
// This should fail - wrong CSRF token
|
|
if headerCSRF == sessionCSRF {
|
|
t.Errorf("CSRF protection failed: request with wrong CSRF token was accepted")
|
|
}
|
|
|
|
if headerCSRF != sessionCSRF {
|
|
t.Logf("CSRF protection working: header token %s != session token %s", headerCSRF, sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 4: CSRF token generation and validation
|
|
t.Run("CSRF token generation", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://example.com/login", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Generate and set CSRF token
|
|
csrfToken := generateRandomString(32)
|
|
if len(csrfToken) != 32 {
|
|
t.Errorf("CSRF token length incorrect: expected 32, got %d", len(csrfToken))
|
|
}
|
|
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Verify token was stored
|
|
storedToken := session.GetCSRF()
|
|
if storedToken != csrfToken {
|
|
t.Errorf("CSRF token storage failed: expected %s, got %s", csrfToken, storedToken)
|
|
}
|
|
|
|
// Verify token is not empty and has reasonable entropy
|
|
if storedToken == "" {
|
|
t.Error("CSRF token is empty")
|
|
}
|
|
|
|
if len(storedToken) < 16 {
|
|
t.Errorf("CSRF token too short: %d characters", len(storedToken))
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestTokenBlacklisting tests the token blacklisting mechanism
|
|
func TestTokenBlacklisting(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a new instance for this test to avoid interference from global state
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Create keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create a valid JWT
|
|
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Create the TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err != nil {
|
|
t.Fatalf("First verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Now blacklist the token directly
|
|
tOidc.tokenBlacklist.Set(validJWT, true, time.Hour)
|
|
|
|
// Second verification should fail due to blacklisting
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err == nil {
|
|
t.Errorf("Verification succeeded despite token being blacklisted")
|
|
} else {
|
|
// Verify the error message indicates the token is blacklisted
|
|
if !strings.Contains(strings.ToLower(err.Error()), "blacklisted") {
|
|
t.Errorf("Expected blacklist error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestDifferentSigningAlgorithms tests that the plugin properly handles different signing algorithms
|
|
func TestDifferentSigningAlgorithms(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test cases for different algorithms - the implementation actually supports multiple algorithms
|
|
testCases := []struct {
|
|
name string
|
|
algorithm string
|
|
keyType string
|
|
shouldSucceed bool
|
|
}{
|
|
// RSA algorithms
|
|
{"RS256 Algorithm", "RS256", "RSA", true},
|
|
{"RS384 Algorithm", "RS384", "RSA", true},
|
|
{"RS512 Algorithm", "RS512", "RSA", true},
|
|
{"PS256 Algorithm", "PS256", "RSA", true},
|
|
{"PS384 Algorithm", "PS384", "RSA", true},
|
|
{"PS512 Algorithm", "PS512", "RSA", true},
|
|
|
|
// EC algorithms
|
|
{"ES256 Algorithm", "ES256", "EC", true},
|
|
{"ES384 Algorithm", "ES384", "EC", true},
|
|
{"ES512 Algorithm", "ES512", "EC", true},
|
|
|
|
// Unsupported algorithms
|
|
{"HS256 Algorithm", "HS256", "RSA", false},
|
|
{"HS384 Algorithm", "HS384", "RSA", false},
|
|
{"HS512 Algorithm", "HS512", "RSA", false},
|
|
{"None Algorithm", "none", "RSA", false},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Define standard claims with unique JTI for each test
|
|
standardClaims := map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16), // Generate unique JTI for each test
|
|
}
|
|
|
|
var jwtToken string
|
|
var err error
|
|
|
|
// Use appropriate key type and create corresponding JWK
|
|
if tc.keyType == "RSA" {
|
|
// Update the RSA JWK to support the current algorithm
|
|
rsaJWK := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: tc.algorithm, // Use the algorithm being tested
|
|
N: base64.RawURLEncoding.EncodeToString(ts.rsaPrivateKey.PublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
|
|
// Update the mock JWK cache with the correct algorithm
|
|
ts.mockJWKCache.JWKS = &JWKSet{
|
|
Keys: []JWK{rsaJWK},
|
|
}
|
|
|
|
jwtToken, err = createTestJWT(ts.rsaPrivateKey, tc.algorithm, "test-key-id", standardClaims)
|
|
if err != nil {
|
|
if !tc.shouldSucceed {
|
|
t.Logf("Expected failure creating JWT with %s algorithm: %v", tc.algorithm, err)
|
|
return // This is expected for unsupported algorithms
|
|
}
|
|
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
|
|
}
|
|
} else if tc.keyType == "EC" {
|
|
// Generate EC key for the specific curve
|
|
var curve elliptic.Curve
|
|
switch tc.algorithm {
|
|
case "ES256":
|
|
curve = elliptic.P256()
|
|
case "ES384":
|
|
curve = elliptic.P384()
|
|
case "ES512":
|
|
curve = elliptic.P521()
|
|
default:
|
|
t.Fatalf("Unsupported EC algorithm: %s", tc.algorithm)
|
|
}
|
|
|
|
ecPrivateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate EC key for %s: %v", tc.algorithm, err)
|
|
}
|
|
|
|
// Create EC JWK for this test
|
|
ecJWK := createECJWK(ecPrivateKey, tc.algorithm, "test-ec-key-id")
|
|
|
|
// Replace the JWK cache entirely with just the EC key for this test
|
|
ts.mockJWKCache.JWKS = &JWKSet{
|
|
Keys: []JWK{ecJWK},
|
|
}
|
|
|
|
// Ensure rate limiter is initialized for EC tests
|
|
if ts.tOidc.limiter == nil {
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
|
}
|
|
|
|
jwtToken, err = createTestJWTWithECKey(ecPrivateKey, tc.algorithm, "test-ec-key-id", standardClaims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
|
|
}
|
|
} else {
|
|
t.Fatalf("Unsupported key type: %s", tc.keyType)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(jwtToken)
|
|
|
|
if tc.shouldSucceed {
|
|
if err != nil {
|
|
t.Errorf("Verification with %s failed: %v", tc.algorithm, err)
|
|
} else {
|
|
t.Logf("Successfully verified token with %s algorithm", tc.algorithm)
|
|
}
|
|
} else {
|
|
if err == nil {
|
|
t.Errorf("Verification with unsupported algorithm %s succeeded", tc.algorithm)
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error for %s, but got: %v", tc.algorithm, err)
|
|
} else {
|
|
t.Logf("Correctly rejected unsupported algorithm %s: %v", tc.algorithm, err)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// createTestJWTWithECKey creates a JWT signed with an EC private key
|
|
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
|
// Create the header
|
|
header := map[string]interface{}{
|
|
"alg": alg,
|
|
"typ": "JWT",
|
|
"kid": kid,
|
|
}
|
|
|
|
// Encode header and claims to base64
|
|
headerJSON, err := json.Marshal(header)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal header: %v", err)
|
|
}
|
|
headerBase64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
|
|
claimsJSON, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal claims: %v", err)
|
|
}
|
|
claimsBase64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
// Create the signing input
|
|
signingInput := headerBase64 + "." + claimsBase64
|
|
|
|
// Create signature based on algorithm
|
|
var signature []byte
|
|
|
|
switch alg {
|
|
case "ES256":
|
|
h := crypto.SHA256.New()
|
|
h.Write([]byte(signingInput))
|
|
hashed := h.Sum(nil)
|
|
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign with ES256: %v", err)
|
|
}
|
|
// For ES256, each coordinate should be 32 bytes (256 bits / 8)
|
|
rBytes := r.Bytes()
|
|
sBytes := s.Bytes()
|
|
if len(rBytes) < 32 {
|
|
padded := make([]byte, 32)
|
|
copy(padded[32-len(rBytes):], rBytes)
|
|
rBytes = padded
|
|
}
|
|
if len(sBytes) < 32 {
|
|
padded := make([]byte, 32)
|
|
copy(padded[32-len(sBytes):], sBytes)
|
|
sBytes = padded
|
|
}
|
|
signature = append(rBytes, sBytes...)
|
|
case "ES384":
|
|
h := crypto.SHA384.New()
|
|
h.Write([]byte(signingInput))
|
|
hashed := h.Sum(nil)
|
|
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign with ES384: %v", err)
|
|
}
|
|
// For ES384 (P-384), each coordinate should be 48 bytes (384 bits / 8)
|
|
rBytes := r.Bytes()
|
|
sBytes := s.Bytes()
|
|
// Pad to exactly 48 bytes each
|
|
if len(rBytes) < 48 {
|
|
padded := make([]byte, 48)
|
|
copy(padded[48-len(rBytes):], rBytes)
|
|
rBytes = padded
|
|
} else if len(rBytes) > 48 {
|
|
// Truncate if too long (shouldn't happen with P-384)
|
|
rBytes = rBytes[len(rBytes)-48:]
|
|
}
|
|
if len(sBytes) < 48 {
|
|
padded := make([]byte, 48)
|
|
copy(padded[48-len(sBytes):], sBytes)
|
|
sBytes = padded
|
|
} else if len(sBytes) > 48 {
|
|
// Truncate if too long (shouldn't happen with P-384)
|
|
sBytes = sBytes[len(sBytes)-48:]
|
|
}
|
|
signature = append(rBytes, sBytes...)
|
|
case "ES512":
|
|
h := crypto.SHA512.New()
|
|
h.Write([]byte(signingInput))
|
|
hashed := h.Sum(nil)
|
|
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign with ES512: %v", err)
|
|
}
|
|
// For ES512 (P-521), each coordinate should be 66 bytes (521 bits / 8 = 65.125, rounded up to 66)
|
|
rBytes := r.Bytes()
|
|
sBytes := s.Bytes()
|
|
// Pad to 66 bytes each
|
|
if len(rBytes) < 66 {
|
|
padded := make([]byte, 66)
|
|
copy(padded[66-len(rBytes):], rBytes)
|
|
rBytes = padded
|
|
} else if len(rBytes) > 66 {
|
|
// Truncate if too long (shouldn't happen with P-521)
|
|
rBytes = rBytes[len(rBytes)-66:]
|
|
}
|
|
if len(sBytes) < 66 {
|
|
padded := make([]byte, 66)
|
|
copy(padded[66-len(sBytes):], sBytes)
|
|
sBytes = padded
|
|
} else if len(sBytes) > 66 {
|
|
// Truncate if too long (shouldn't happen with P-521)
|
|
sBytes = sBytes[len(sBytes)-66:]
|
|
}
|
|
signature = append(rBytes, sBytes...)
|
|
default:
|
|
return "", fmt.Errorf("unsupported EC algorithm: %s", alg)
|
|
}
|
|
|
|
// Encode signature
|
|
signatureBase64 := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
// Combine to create JWT
|
|
return signingInput + "." + signatureBase64, nil
|
|
}
|
|
|
|
// createECJWK creates a JWK from an EC private key
|
|
func createECJWK(privateKey *ecdsa.PrivateKey, alg, kid string) JWK {
|
|
// Get the curve name
|
|
var crv string
|
|
switch privateKey.Curve {
|
|
case elliptic.P256():
|
|
crv = "P-256"
|
|
case elliptic.P384():
|
|
crv = "P-384"
|
|
case elliptic.P521():
|
|
crv = "P-521"
|
|
default:
|
|
panic("unsupported curve")
|
|
}
|
|
|
|
// Get the key size for coordinate encoding
|
|
keySize := (privateKey.Curve.Params().BitSize + 7) / 8
|
|
|
|
// Encode X and Y coordinates
|
|
xBytes := privateKey.PublicKey.X.Bytes()
|
|
yBytes := privateKey.PublicKey.Y.Bytes()
|
|
|
|
// Pad to the correct length
|
|
if len(xBytes) < keySize {
|
|
padded := make([]byte, keySize)
|
|
copy(padded[keySize-len(xBytes):], xBytes)
|
|
xBytes = padded
|
|
}
|
|
if len(yBytes) < keySize {
|
|
padded := make([]byte, keySize)
|
|
copy(padded[keySize-len(yBytes):], yBytes)
|
|
yBytes = padded
|
|
}
|
|
|
|
return JWK{
|
|
Kty: "EC",
|
|
Kid: kid,
|
|
Alg: alg,
|
|
Crv: crv,
|
|
X: base64.RawURLEncoding.EncodeToString(xBytes),
|
|
Y: base64.RawURLEncoding.EncodeToString(yBytes),
|
|
}
|
|
}
|
|
|
|
// TestMalformedTokens tests the plugin's handling of malformed tokens
|
|
func TestMalformedTokens(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
token string
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "Empty Token",
|
|
token: "",
|
|
expectedError: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Missing Parts",
|
|
token: "header.payload",
|
|
expectedError: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Header",
|
|
token: "invalid!base64.payload.signature",
|
|
expectedError: "failed to decode header",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Payload",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.invalid!base64.signature",
|
|
expectedError: "failed to decode claims",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Signature",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid!base64",
|
|
expectedError: "failed to decode signature",
|
|
},
|
|
{
|
|
name: "Invalid JSON in Header",
|
|
token: "eyJpbnZhbGlkIGpzb24=.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature",
|
|
expectedError: "failed to decode header",
|
|
},
|
|
{
|
|
name: "Invalid JSON in Payload",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJpbnZhbGlkIGpzb24=.signature",
|
|
expectedError: "failed to decode claims",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := ts.tOidc.VerifyToken(tc.token)
|
|
|
|
// Should fail with expected error
|
|
if err == nil {
|
|
t.Errorf("Malformed token was incorrectly verified: %s", tc.token)
|
|
} else {
|
|
if !strings.Contains(err.Error(), tc.expectedError) {
|
|
t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRateLimiting tests the rate limiting functionality to prevent brute force attacks
|
|
func TestRateLimiting(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a fresh instance for this test to avoid affecting other tests with rate limiting
|
|
logger := NewLogger("debug")
|
|
|
|
// Create a new test suite for this test only
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a separate TraefikOidc instance with a very restrictive rate limiter
|
|
// This prevents the global instance from being rate-limited
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
// Allow only 2 requests per 10 seconds
|
|
limiter: rate.NewLimiter(rate.Every(10*time.Second), 2),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Create a valid JWT token
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// First request should succeed
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err != nil {
|
|
t.Fatalf("First token verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Second request should succeed
|
|
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create second valid JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(validJWT2)
|
|
if err != nil {
|
|
t.Fatalf("Second token verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Third request should be rate limited
|
|
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create third valid JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(validJWT3)
|
|
if err == nil {
|
|
t.Errorf("Third token verification succeeded despite rate limiting")
|
|
} else {
|
|
// Check that the error message indicates rate limiting
|
|
if !strings.Contains(strings.ToLower(err.Error()), "rate") {
|
|
t.Errorf("Expected rate limiting error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestAuthorizationHeaderBypass tests that the plugin correctly handles attempts to bypass
|
|
// authorization by directly providing an Authorization header
|
|
func TestAuthorizationHeaderBypass(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a test next handler that would indicate successful authentication
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("Authenticated"))
|
|
})
|
|
|
|
// Create the TraefikOidc instance
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: NewLogger("debug"),
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: ts.sessionManager,
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Create a request with a forged Authorization header but no valid session
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Add a forged Authorization header
|
|
req.Header.Set("Authorization", "Bearer "+ts.token)
|
|
|
|
// Record the response
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Process the request
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
// The middleware should not honor the direct Authorization header
|
|
// and should either redirect to authentication or return an error
|
|
if resp.Code == http.StatusOK {
|
|
body := resp.Body.String()
|
|
if body == "Authenticated" {
|
|
t.Errorf("Authorization header bypass succeeded - request was authenticated without a valid session")
|
|
}
|
|
}
|
|
|
|
// Verify that the response is a redirect to authentication (302) or unauthorized (401)
|
|
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized}
|
|
codeFound := false
|
|
for _, code := range expectedCodes {
|
|
if resp.Code == code {
|
|
codeFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !codeFound {
|
|
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, resp.Code)
|
|
}
|
|
}
|
|
|
|
// TestEmptyAudience tests tokens with empty audience claim
|
|
func TestEmptyAudience(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with empty audience
|
|
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "", // Empty audience
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with empty audience: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(emptyAudJWT)
|
|
|
|
// Should fail due to invalid audience
|
|
if err == nil {
|
|
t.Errorf("Token with empty audience was incorrectly verified")
|
|
} else {
|
|
// Check error message
|
|
if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected invalid audience error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestEmptyIssuer tests tokens with empty issuer claim
|
|
func TestEmptyIssuer(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with empty issuer
|
|
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "", // Empty issuer
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with empty issuer: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(emptyIssJWT)
|
|
|
|
// Should fail due to invalid issuer
|
|
if err == nil {
|
|
t.Errorf("Token with empty issuer was incorrectly verified")
|
|
} else {
|
|
// Check error message
|
|
if !strings.Contains(err.Error(), "invalid issuer") {
|
|
t.Errorf("Expected invalid issuer error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestInvalidRedirectURI tests the plugin's handling of invalid redirect URIs
|
|
func TestInvalidRedirectURI(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a test request with an invalid redirect URI
|
|
req := httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Create a session with a state
|
|
session, err := ts.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Set legitimate state and redirect
|
|
session.mainSession.Values["state"] = "validstate"
|
|
session.mainSession.Values["redirect"] = "/legitimate-redirect"
|
|
|
|
resp := httptest.NewRecorder()
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create a new request with those cookies
|
|
req = httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Add cookies
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Create a next handler for the middleware
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Create the TraefikOidc instance
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: NewLogger("debug"),
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: ts.sessionManager,
|
|
tokenExchanger: ts.tOidc.tokenExchanger,
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Process the callback request
|
|
resp = httptest.NewRecorder()
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
// Check if open redirect is blocked
|
|
// The response should not redirect to the evil.com domain
|
|
location := resp.Header().Get("Location")
|
|
if location != "" && strings.Contains(location, "evil.com") {
|
|
t.Errorf("Open redirect vulnerability - redirected to %s", location)
|
|
}
|
|
|
|
// Should redirect to the legitimate URL
|
|
if location != "" && !strings.Contains(location, "/legitimate-redirect") {
|
|
t.Errorf("Expected redirect to /legitimate-redirect, but got: %s", location)
|
|
}
|
|
}
|