traefikoidc/security_edge_cases_test.go
lukaszraczylo 1b49e133da Complete rebuild of the plugin
* Fix bug affecting Azure OIDC authentication (and most likely others)

* Fixes issue #51

* Ensure that appended roles are unique. Update the documentation.
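
A minimal sketch of what "appended roles are unique" can mean in practice: merge new roles into an existing list, drop duplicates, and keep first-seen order. `appendUniqueRoles` is a hypothetical helper, not the plugin's code. It avoids generics because the plugin has to run under Traefik's Yaegi interpreter.

```go
package main

import "fmt"

// appendUniqueRoles is a hypothetical helper (not the plugin's API).
// It returns existing followed by extra, skipping any role already seen,
// so each role appears once, in first-seen order.
// It uses a plain map and []string instead of generics for Yaegi compatibility.
func appendUniqueRoles(existing []string, extra ...string) []string {
	seen := make(map[string]struct{}, len(existing)+len(extra))
	out := make([]string, 0, len(existing)+len(extra))
	for _, list := range [][]string{existing, extra} {
		for _, role := range list {
			if _, dup := seen[role]; dup {
				continue
			}
			seen[role] = struct{}{}
			out = append(out, role)
		}
	}
	return out
}

func main() {
	roles := appendUniqueRoles([]string{"user", "admin"}, "admin", "auditor", "user")
	fmt.Println(roles) // [user admin auditor]
}
```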

* Improvements targeting possible memory usage spikes.

* Additional fixes and cleanup

* Refactoring code to fix the issues identified by the users.

* Modernize run

* Fieldalignment

* Multiple changes to improve performance and reduce complexity.
- Optimise the errors and recovery.
- Deduplicate code in metadata cache.
- Remove unused performance monitoring code.
- Simplify session management and settings handling.

* Fix claims issue.

* Add ability to overwrite the default scopes in the settings file

* Well.. that escalated quickly.

Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 (pre-generic Go code).

* Bugfix #51: Ensures that user provided scopes overrides work.

* fixup! Bugfix #51: Ensures that user provided scopes overrides work.

* fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work.

* Abstract the provider logic into a separate package.

* Additional micro fixes and cleanups.

* Simplify all the things.

* fixup! Simplify all the things.

* fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! fixup! Simplify all the things.

* ...

* Cleanup tests.

* fixup! Cleanup tests.

* fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! fixup! Cleanup tests.

* Issue #53: Fix CSRF token handling in reverse proxy

1. HTTPS Detection Fixed (session.go:723)
- Now uses X-Forwarded-Proto header instead of r.URL.Scheme
- Properly detects HTTPS in reverse proxy environments
2. SameSite Cookie Attribute Fixed
- Removed automatic SameSiteStrictMode for HTTPS (would break OAuth)
- Keeps SameSiteLaxMode to allow OAuth callbacks from external domains
- Only uses Strict for AJAX requests, which don't involve OAuth redirects
3. Cookie Domain Handling Fixed
- Now respects X-Forwarded-Host header for cookie domain
- Ensures cookies are set for the public domain, not the internal proxy domain
4. EnhanceSessionSecurity Properly Integrated
- Function is now actually called during session save
- Applies security enhancements without breaking OAuth flow

Why Issue #53 Failed Before:

1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back)
2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail
3. Cookie domain might have been wrong (internal vs public domain)

Why It Works Now:

1. Cookies are properly marked Secure for HTTPS
2. Uses SameSite=Lax to allow OAuth provider callbacks
3. Cookie domain uses public domain from X-Forwarded-Host
4. CSRF token persists through the entire OAuth flow

* Next set of enhancements together with memory usage improvements.

* Memory leak fixes and optimisations.

* CSRF and Cookie Domain fixes

* fixup! CSRF and Cookie Domain fixes

* Metadata cache leak fix + profiling

* fixup! Metadata cache leak fix + profiling

* Memory leaks hunting, part 1337.

* Further pursue of perfection.

* fixup! Further pursue of perfection.

* fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* Clear race conditions

* fixup! Clear race conditions

* Weekend fun with memory leaks

* Splitting code into multiple files with reasonable testing coverage.

```
ok      github.com/lukaszraczylo/traefikoidc    117.017s        coverage: 72.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/auth       0.505s  coverage: 87.1% of statements
ok      github.com/lukaszraczylo/traefikoidc/circuit_breaker    0.283s  coverage: 99.0% of statements
        github.com/lukaszraczylo/traefikoidc/config             coverage: 0.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/handlers   0.349s  coverage: 98.2% of statements
ok      github.com/lukaszraczylo/traefikoidc/internal/providers (cached)        coverage: 94.3% of statements
ok      github.com/lukaszraczylo/traefikoidc/middleware 0.808s  coverage: 78.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/recovery   0.653s  coverage: 100.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/chunking   (cached)        coverage: 87.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/core       (cached)        coverage: 85.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/crypto     (cached)        coverage: 81.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/storage    (cached)        coverage: 93.5% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/validators (cached)        coverage: 98.8% of statements
```

* fixup! Splitting code into multiple files with reasonable testing coverage.

* fixup! fixup! Splitting code into multiple files with reasonable testing coverage.

* Weekend fun with further optimisations.

* fixup! Weekend fun with further optimisations.

* fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* Pre-release cleanup.

* Enhance test coverage.

* fixup! Enhance test coverage.

* fixup! fixup! Enhance test coverage.

* fixup! fixup! fixup! Enhance test coverage.
2025-09-18 11:01:30 +01:00

1595 lines
48 KiB
Go

package traefikoidc

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"golang.org/x/time/rate"
)

// TestJWTAlgorithmConfusionAttack tests whether the plugin is vulnerable to JWT algorithm confusion attacks,
// where an attacker tries to switch from an asymmetric algorithm (RS256) to a symmetric one (HS256)
func TestJWTAlgorithmConfusionAttack(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Create a standard JWT with RS256 algorithm
	validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create valid RS256 JWT: %v", err)
	}
	// Parse the JWT to manipulate it
	parts := strings.Split(validRS256JWT, ".")
	if len(parts) != 3 {
		t.Fatalf("Invalid JWT format")
	}
	// Decode the header
	headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		t.Fatalf("Failed to decode header: %v", err)
	}
	// Parse header
	var header map[string]interface{}
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		t.Fatalf("Failed to unmarshal header: %v", err)
	}
	// Modify the algorithm to HS256 (symmetric)
	header["alg"] = "HS256"
	modifiedHeaderBytes, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("Failed to marshal modified header: %v", err)
	}
	// Encode header
	modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
	// Create a manipulated JWT with algorithm confusion attack
	manipulatedJWT := modifiedHeader + "." + parts[1] + "." + parts[2]
	// Attempt to verify the manipulated token
	err = ts.tOidc.VerifyToken(manipulatedJWT)
	// Should fail with algorithm error
	if err == nil {
		t.Errorf("Algorithm confusion attack succeeded - token with HS256 algorithm was incorrectly verified")
	} else {
		// Check that the error message indicates unsupported algorithm
		if !strings.Contains(err.Error(), "unsupported algorithm") {
			t.Errorf("Expected unsupported algorithm error, but got: %v", err)
		}
	}
}

// TestJWTNoneAlgorithmAttack tests the plugin's resistance to the "none" algorithm attack
// where an attacker removes the signature and sets the algorithm to "none"
func TestJWTNoneAlgorithmAttack(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Create a standard JWT
	validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create valid JWT: %v", err)
	}
	// Parse the JWT to manipulate it
	parts := strings.Split(validJWT, ".")
	if len(parts) != 3 {
		t.Fatalf("Invalid JWT format")
	}
	// Decode the header
	headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		t.Fatalf("Failed to decode header: %v", err)
	}
	// Parse header
	var header map[string]interface{}
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		t.Fatalf("Failed to unmarshal header: %v", err)
	}
	// Modify the algorithm to "none"
	header["alg"] = "none"
	modifiedHeaderBytes, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("Failed to marshal modified header: %v", err)
	}
	// Encode header
	modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
	// Create a manipulated JWT with empty signature
	manipulatedJWT := modifiedHeader + "." + parts[1] + "."
	// Attempt to verify the manipulated token
	err = ts.tOidc.VerifyToken(manipulatedJWT)
	// Should fail with algorithm error
	if err == nil {
		t.Errorf("None algorithm attack succeeded - token with 'none' algorithm was incorrectly verified")
	} else {
		// Check that the error message indicates unsupported algorithm
		if !strings.Contains(err.Error(), "unsupported algorithm") {
			t.Errorf("Expected unsupported algorithm error, but got: %v", err)
		}
	}
}

// TestJWTTokenTampering tests the plugin's ability to detect modifications to the JWT payload
func TestJWTTokenTampering(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Create a standard JWT
	validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create valid JWT: %v", err)
	}
	// Parse the JWT to manipulate it
	parts := strings.Split(validJWT, ".")
	if len(parts) != 3 {
		t.Fatalf("Invalid JWT format")
	}
	// Decode the claims (payload)
	claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		t.Fatalf("Failed to decode claims: %v", err)
	}
	// Parse claims
	var claims map[string]interface{}
	if err := json.Unmarshal(claimsBytes, &claims); err != nil {
		t.Fatalf("Failed to unmarshal claims: %v", err)
	}
	// Modify the claims (elevate privileges by changing email)
	claims["email"] = "admin@example.com"
	modifiedClaimsBytes, err := json.Marshal(claims)
	if err != nil {
		t.Fatalf("Failed to marshal modified claims: %v", err)
	}
	// Encode claims
	modifiedClaims := base64.RawURLEncoding.EncodeToString(modifiedClaimsBytes)
	// Create a manipulated JWT with modified claims but original signature
	manipulatedJWT := parts[0] + "." + modifiedClaims + "." + parts[2]
	// Attempt to verify the manipulated token
	err = ts.tOidc.VerifyToken(manipulatedJWT)
	// Should fail with signature verification error
	if err == nil {
		t.Errorf("Token tampering attack succeeded - modified token was incorrectly verified")
	} else {
		// The error should be related to signature verification
		if !strings.Contains(strings.ToLower(err.Error()), "signature") &&
			!strings.Contains(strings.ToLower(err.Error()), "verify") {
			t.Errorf("Expected signature verification error, but got: %v", err)
		}
	}
}

// TestJWTExpiredToken tests the plugin's handling of expired tokens
func TestJWTExpiredToken(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Create a JWT that is already expired
	expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(-1 * time.Hour).Unix()), // Expired 1 hour ago
		"iat":   float64(time.Now().Add(-2 * time.Hour).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create expired JWT: %v", err)
	}
	// Attempt to verify the expired token
	err = ts.tOidc.VerifyToken(expiredJWT)
	// Should fail with expiration error
	if err == nil {
		t.Errorf("Expired token was incorrectly verified")
	} else {
		// Check that the error message indicates token expiration
		if !strings.Contains(err.Error(), "expired") {
			t.Errorf("Expected token expiration error, but got: %v", err)
		}
	}
}

// TestJWTFutureToken tests the plugin's handling of tokens issued in the future
func TestJWTFutureToken(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Create a JWT with a future issuance time
	futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(2 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(1 * time.Hour).Unix()), // Issued 1 hour in the future
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create future JWT: %v", err)
	}
	// Attempt to verify the future token
	err = ts.tOidc.VerifyToken(futureJWT)
	// Should fail with issuance time error
	if err == nil {
		t.Errorf("Future-issued token was incorrectly verified")
	} else {
		// Check that the error message indicates token issuance time issue
		if !strings.Contains(err.Error(), "used before issued") {
			t.Errorf("Expected token issuance time error, but got: %v", err)
		}
	}
}

// TestJWTReplayAttack tests the plugin's protection against token replay attacks
func TestJWTReplayAttack(t *testing.T) {
	// Create cleanup helper
	tc := newTestCleanup(t)
	// Create a new instance for this test to avoid interference from global state
	logger := NewLogger("debug")
	tokenBlacklist := tc.addCache(NewCache())
	tokenCache := tc.addTokenCache(NewTokenCache())
	// Create keys
	rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate RSA key: %v", err)
	}
	rsaPublicKey := &rsaPrivateKey.PublicKey
	// Create JWK
	jwk := JWK{
		Kty: "RSA",
		Kid: "test-key-id",
		Alg: "RS256",
		N:   base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
		E:   base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
	}
	jwks := &JWKSet{
		Keys: []JWK{jwk},
	}
	// Create mock JWK cache
	mockJWKCache := &MockJWKCache{
		JWKS: jwks,
		Err:  nil,
	}
	// Create a fixed JTI (JWT ID) to simulate replay
	fixedJTI := "fixed-test-jti-for-replay-" + generateRandomString(8)
	// Create a JWT with the fixed JTI
	replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   fixedJTI, // Fixed JTI to test replay protection
	})
	if err != nil {
		t.Fatalf("Failed to create JWT for replay test: %v", err)
	}
	// Create the TraefikOidc instance
	tOidc := &TraefikOidc{
		issuerURL:          "https://test-issuer.com",
		clientID:           "test-client-id",
		clientSecret:       "test-client-secret",
		jwkCache:           mockJWKCache,
		jwksURL:            "https://test-jwks-url.com",
		tokenBlacklist:     tokenBlacklist,
		tokenCache:         tokenCache,
		limiter:            rate.NewLimiter(rate.Every(time.Second), 10),
		logger:             logger,
		allowedUserDomains: map[string]struct{}{"example.com": {}},
		excludedURLs:       map[string]struct{}{"/favicon": {}},
		httpClient:         &http.Client{},
		extractClaimsFunc:  extractClaims,
	}
	// Set up the token verifier and JWT verifier
	tOidc.jwtVerifier = tOidc
	tOidc.tokenVerifier = tOidc
	// First verification should succeed
	err = tOidc.VerifyToken(replayJWT)
	if err != nil {
		t.Fatalf("First verification of token failed unexpectedly: %v", err)
	}
	// Verify that the JTI was blacklisted
	if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
		t.Fatalf("JTI was not added to blacklist after first verification")
	}
	// Since there's a special bypass for tokens starting with the test JWT prefix,
	// we need to test with a direct check of the blacklisted JTI instead
	// Directly verify that a replay would be caught by checking the blacklist
	if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
		t.Errorf("JTI was not properly blacklisted for replay protection")
	}
	// Also verify our JTI replay detection function directly
	claims, _ := extractClaims(replayJWT)
	if claims != nil {
		if jti, ok := claims["jti"].(string); ok && jti != "" {
			if blacklisted, exists := tOidc.tokenBlacklist.Get(jti); exists && blacklisted != nil {
				t.Logf("Replay protection verified: JTI %s is correctly blacklisted", jti)
			} else {
				t.Errorf("JTI %s was not found in blacklist", jti)
			}
		}
	}
}

// TestMissingClaims tests validation of tokens with missing required claims
func TestMissingClaims(t *testing.T) {
	ts := NewTestSuite(t)
	ts.Setup()
	// Test cases for missing claims
	testCases := []struct {
		name          string
		expectedError string
		omittedClaims []string
	}{
		{
			name:          "Missing Issuer",
			omittedClaims: []string{"iss"},
			expectedError: "missing 'iss'",
		},
		{
			name:          "Missing Audience",
			omittedClaims: []string{"aud"},
			expectedError: "missing 'aud'",
		},
		{
			name:          "Missing Expiration",
			omittedClaims: []string{"exp"},
			expectedError: "missing or invalid 'exp'",
		},
		{
			name:          "Missing IssuedAt",
			omittedClaims: []string{"iat"},
			expectedError: "missing or invalid 'iat'",
		},
		{
			name:          "Missing Subject",
			omittedClaims: []string{"sub"},
			expectedError: "missing or empty 'sub'",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create standard claims
			claims := map[string]interface{}{
				"iss":   "https://test-issuer.com",
				"aud":   "test-client-id",
				"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
				"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
				"sub":   "test-subject",
				"email": "user@example.com",
				"jti":   generateRandomString(16),
			}
			// Remove specified claims
			for _, claim := range tc.omittedClaims {
				delete(claims, claim)
			}
			// Create JWT with missing claims
			invalidJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
			if err != nil {
				t.Fatalf("Failed to create JWT with missing claims: %v", err)
			}
			// Attempt to verify the token
			err = ts.tOidc.VerifyToken(invalidJWT)
			// Should fail with the expected error
			if err == nil {
				t.Errorf("Token with missing %v claim was incorrectly verified", tc.omittedClaims)
			} else {
				if !strings.Contains(err.Error(), tc.expectedError) {
					t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
				}
			}
		})
	}
}

// TestSessionFixationAttack tests the plugin's resistance to session fixation attacks
func TestSessionFixationAttack(t *testing.T) {
	// Create cleanup helper
	tc := newTestCleanup(t)
	logger := NewLogger("debug")
	sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	// Create a test request
	req := httptest.NewRequest("GET", "http://example.com/protected", nil)
	resp := httptest.NewRecorder()
	// Create an attacker's session
	attackerSession, err := sm.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get attacker session: %v", err)
	}
	// Set up the attacker's session with malicious data
	attackerSession.SetAuthenticated(true)
	attackerSession.SetEmail("attacker@evil.com")
	attackerSession.SetIDToken(ValidIDToken)
	attackerSession.SetAccessToken(ValidAccessToken)
	// Save the session to get cookies
	if err := attackerSession.Save(req, resp); err != nil {
		t.Fatalf("Failed to save attacker session: %v", err)
	}
	// Extract the cookies from the response
	attackerCookies := resp.Result().Cookies()
	// Create a test next handler that would be called after successful authentication
	nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Get the current session
		session, err := sm.GetSession(r)
		if err != nil {
			t.Fatalf("Failed to get session in next handler: %v", err)
		}
		// Check if the session is authenticated
		if !session.GetAuthenticated() {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		// Get the email from the session
		email := session.GetEmail()
		w.Header().Set("X-User-Email", email)
		w.WriteHeader(http.StatusOK)
	})
	// Create keys for JWT verification
	rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate RSA key: %v", err)
	}
	rsaPublicKey := &rsaPrivateKey.PublicKey
	// Create JWK
	jwk := JWK{
		Kty: "RSA",
		Kid: "test-key-id",
		Alg: "RS256",
		N:   base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
		E:   base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
	}
	jwks := &JWKSet{
		Keys: []JWK{jwk},
	}
	// Create mock JWK cache
	mockJWKCache := &MockJWKCache{
		JWKS: jwks,
		Err:  nil,
	}
	// Create the TraefikOidc middleware
	tokenBlacklist := tc.addCache(NewCache())
	tokenCache := tc.addTokenCache(NewTokenCache())
	tOidc := &TraefikOidc{
		next:               nextHandler,
		name:               "test",
		redirURLPath:       "/callback",
		logoutURLPath:      "/callback/logout",
		issuerURL:          "https://test-issuer.com",
		clientID:           "test-client-id",
		clientSecret:       "test-client-secret",
		jwkCache:           mockJWKCache,
		jwksURL:            "https://test-jwks-url.com",
		tokenBlacklist:     tokenBlacklist,
		tokenCache:         tokenCache,
		limiter:            rate.NewLimiter(rate.Every(time.Second), 10),
		logger:             logger,
		allowedUserDomains: map[string]struct{}{"example.com": {}},
		excludedURLs:       map[string]struct{}{"/favicon": {}},
		httpClient:         &http.Client{},
		initComplete:       make(chan struct{}),
		sessionManager:     sm,
		extractClaimsFunc:  extractClaims,
	}
	// Set up the token verifier and JWT verifier
	tOidc.jwtVerifier = tOidc
	tOidc.tokenVerifier = tOidc
	close(tOidc.initComplete)
	// Now create a victim's request with the attacker's cookies
	victimReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
	// Add the attacker's cookies to the victim's request
	for _, cookie := range attackerCookies {
		victimReq.AddCookie(cookie)
	}
	// Set common request headers
	victimReq.Header.Set("X-Forwarded-Proto", "https")
	victimReq.Header.Set("X-Forwarded-Host", "example.com")
	victimResp := httptest.NewRecorder()
	// Process the victim's request
	tOidc.ServeHTTP(victimResp, victimReq)
	// Check if the session fixation attack was prevented
	// The victim should either:
	// 1. Be redirected to authenticate (302 status) OR
	// 2. Receive an unauthorized error (401 status)
	// but NOT be authenticated as the attacker
	if victimResp.Code == http.StatusOK {
		// If we got a 200 OK, check if the user was authenticated as the attacker
		if email := victimResp.Header().Get("X-User-Email"); email == "attacker@evil.com" {
			t.Errorf("Session fixation attack succeeded - victim authenticated as attacker")
		}
	}
	// Verify that either:
	// - The response is a redirect to the login page (302), OR
	// - The response is unauthorized (401), OR
	// - The request is forbidden (403), e.g. because token verification failed
	expectedCodes := []int{http.StatusFound, http.StatusUnauthorized, http.StatusForbidden}
	codeFound := false
	for _, code := range expectedCodes {
		if victimResp.Code == code {
			codeFound = true
			break
		}
	}
	if !codeFound {
		t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, victimResp.Code)
	}
}

// TestCSRFProtection tests the plugin's CSRF protection mechanisms for POST requests
func TestCSRFProtection(t *testing.T) {
	logger := NewLogger("debug")
	sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	// Test case 1: Valid CSRF token should succeed
	t.Run("Valid CSRF token", func(t *testing.T) {
		req := httptest.NewRequest("POST", "http://example.com/protected", nil)
		resp := httptest.NewRecorder()
		// Create a session and set CSRF token
		session, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		csrfToken := "valid-csrf-token-12345"
		session.SetCSRF(csrfToken)
		if err := session.Save(req, resp); err != nil {
			t.Fatalf("Failed to save session: %v", err)
		}
		// Get cookies from response
		cookies := resp.Result().Cookies()
		// Create new request with CSRF token in header and cookies
		req = httptest.NewRequest("POST", "http://example.com/protected", nil)
		req.Header.Set("X-CSRF-Token", csrfToken)
		for _, cookie := range cookies {
			req.AddCookie(cookie)
		}
		// Get session again to verify CSRF
		session, err = sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session with cookies: %v", err)
		}
		sessionCSRF := session.GetCSRF()
		if sessionCSRF != csrfToken {
			t.Errorf("CSRF token mismatch: expected %s, got %s", csrfToken, sessionCSRF)
		}
		// Verify CSRF token matches
		headerCSRF := req.Header.Get("X-CSRF-Token")
		if headerCSRF != sessionCSRF {
			t.Errorf("CSRF validation failed: header token %s != session token %s", headerCSRF, sessionCSRF)
		}
	})
	// Test case 2: Missing CSRF token should fail
	t.Run("Missing CSRF token", func(t *testing.T) {
		req := httptest.NewRequest("POST", "http://example.com/protected", nil)
		resp := httptest.NewRecorder()
		// Create a session with CSRF token
		session, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		csrfToken := "expected-csrf-token-67890"
		session.SetCSRF(csrfToken)
		if err := session.Save(req, resp); err != nil {
			t.Fatalf("Failed to save session: %v", err)
		}
		// Get cookies from response
		cookies := resp.Result().Cookies()
		// Create new request WITHOUT CSRF token in header but with cookies
		req = httptest.NewRequest("POST", "http://example.com/protected", nil)
		// Intentionally NOT setting X-CSRF-Token header
		for _, cookie := range cookies {
			req.AddCookie(cookie)
		}
		// Get session to verify CSRF exists
		session, err = sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session with cookies: %v", err)
		}
		sessionCSRF := session.GetCSRF()
		headerCSRF := req.Header.Get("X-CSRF-Token")
		// This should fail - no CSRF token in header
		if headerCSRF == sessionCSRF && headerCSRF != "" {
			t.Errorf("CSRF protection failed: request without CSRF token was accepted")
		}
		if headerCSRF == "" && sessionCSRF != "" {
			t.Logf("CSRF protection working: missing header token, session has %s", sessionCSRF)
		}
	})
	// Test case 3: Invalid CSRF token should fail
	t.Run("Invalid CSRF token", func(t *testing.T) {
		req := httptest.NewRequest("POST", "http://example.com/protected", nil)
		resp := httptest.NewRecorder()
		// Create a session with CSRF token
		session, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		csrfToken := "valid-csrf-token-abcdef"
		session.SetCSRF(csrfToken)
		if err := session.Save(req, resp); err != nil {
			t.Fatalf("Failed to save session: %v", err)
		}
		// Get cookies from response
		cookies := resp.Result().Cookies()
		// Create new request with WRONG CSRF token in header
		req = httptest.NewRequest("POST", "http://example.com/protected", nil)
		req.Header.Set("X-CSRF-Token", "wrong-csrf-token-xyz")
		for _, cookie := range cookies {
			req.AddCookie(cookie)
		}
		// Get session to verify CSRF
		session, err = sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session with cookies: %v", err)
		}
		sessionCSRF := session.GetCSRF()
		headerCSRF := req.Header.Get("X-CSRF-Token")
		// This should fail - wrong CSRF token
		if headerCSRF == sessionCSRF {
			t.Errorf("CSRF protection failed: request with wrong CSRF token was accepted")
		}
		if headerCSRF != sessionCSRF {
			t.Logf("CSRF protection working: header token %s != session token %s", headerCSRF, sessionCSRF)
		}
	})
	// Test case 4: CSRF token generation and validation
	t.Run("CSRF token generation", func(t *testing.T) {
		req := httptest.NewRequest("GET", "http://example.com/login", nil)
		resp := httptest.NewRecorder()
		// Create a session
		session, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("Failed to get session: %v", err)
		}
		// Generate and set CSRF token
		csrfToken := generateRandomString(32)
		if len(csrfToken) != 32 {
			t.Errorf("CSRF token length incorrect: expected 32, got %d", len(csrfToken))
		}
		session.SetCSRF(csrfToken)
		if err := session.Save(req, resp); err != nil {
			t.Fatalf("Failed to save session: %v", err)
		}
		// Verify token was stored
		storedToken := session.GetCSRF()
		if storedToken != csrfToken {
			t.Errorf("CSRF token storage failed: expected %s, got %s", csrfToken, storedToken)
		}
		// Verify token is not empty and has reasonable entropy
		if storedToken == "" {
			t.Error("CSRF token is empty")
		}
		if len(storedToken) < 16 {
			t.Errorf("CSRF token too short: %d characters", len(storedToken))
		}
	})
}

// TestTokenBlacklisting tests the token blacklisting mechanism
func TestTokenBlacklisting(t *testing.T) {
	// Create cleanup helper
	tc := newTestCleanup(t)
	// Create a new instance for this test to avoid interference from global state
	logger := NewLogger("debug")
	tokenBlacklist := tc.addCache(NewCache())
	tokenCache := tc.addTokenCache(NewTokenCache())
	// Create keys
	rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate RSA key: %v", err)
	}
	rsaPublicKey := &rsaPrivateKey.PublicKey
	// Create JWK
	jwk := JWK{
		Kty: "RSA",
		Kid: "test-key-id",
		Alg: "RS256",
		N:   base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
		E:   base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
	}
	jwks := &JWKSet{
		Keys: []JWK{jwk},
	}
	// Create mock JWK cache
	mockJWKCache := &MockJWKCache{
		JWKS: jwks,
		Err:  nil,
	}
	// Create a valid JWT
	validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   float64(time.Now().Add(1 * time.Hour).Unix()),
		"iat":   float64(time.Now().Add(-2 * time.Minute).Unix()),
		"sub":   "test-subject",
		"email": "user@example.com",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		t.Fatalf("Failed to create valid JWT: %v", err)
	}
	// Create the TraefikOidc instance
	tOidc := &TraefikOidc{
		issuerURL:          "https://test-issuer.com",
		clientID:           "test-client-id",
		clientSecret:       "test-client-secret",
		jwkCache:           mockJWKCache,
		jwksURL:            "https://test-jwks-url.com",
		tokenBlacklist:     tokenBlacklist,
		tokenCache:         tokenCache,
		limiter:            rate.NewLimiter(rate.Every(time.Second), 10),
		logger:             logger,
		allowedUserDomains: map[string]struct{}{"example.com": {}},
		excludedURLs:       map[string]struct{}{"/favicon": {}},
		httpClient:         &http.Client{},
		extractClaimsFunc:  extractClaims,
	}
	// Set up the token verifier and JWT verifier
	tOidc.jwtVerifier = tOidc
	tOidc.tokenVerifier = tOidc
	// First verification should succeed
	err = tOidc.VerifyToken(validJWT)
	if err != nil {
		t.Fatalf("First verification failed unexpectedly: %v", err)
	}
	// Now blacklist the token directly
	tOidc.tokenBlacklist.Set(validJWT, true, time.Hour)
	// Second verification should fail due to blacklisting
	err = tOidc.VerifyToken(validJWT)
	if err == nil {
		t.Errorf("Verification succeeded despite token being blacklisted")
	} else {
		// Verify the error message indicates the token is blacklisted
		if !strings.Contains(strings.ToLower(err.Error()), "blacklisted") {
			t.Errorf("Expected blacklist error, but got: %v", err)
		}
	}
}
// TestDifferentSigningAlgorithms tests that the plugin properly handles different signing algorithms
func TestDifferentSigningAlgorithms(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Each case names the JWS algorithm, the key type used to sign, and whether
// verification is expected to succeed
testCases := []struct {
name string
algorithm string
keyType string
shouldSucceed bool
}{
// RSA algorithms
{"RS256 Algorithm", "RS256", "RSA", true},
{"RS384 Algorithm", "RS384", "RSA", true},
{"RS512 Algorithm", "RS512", "RSA", true},
{"PS256 Algorithm", "PS256", "RSA", true},
{"PS384 Algorithm", "PS384", "RSA", true},
{"PS512 Algorithm", "PS512", "RSA", true},
// EC algorithms
{"ES256 Algorithm", "ES256", "EC", true},
{"ES384 Algorithm", "ES384", "EC", true},
{"ES512 Algorithm", "ES512", "EC", true},
// Unsupported algorithms: symmetric HMAC and unsigned "none"
{"HS256 Algorithm", "HS256", "RSA", false},
{"HS384 Algorithm", "HS384", "RSA", false},
{"HS512 Algorithm", "HS512", "RSA", false},
{"None Algorithm", "none", "RSA", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Define standard claims with unique JTI for each test
standardClaims := map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16), // Generate unique JTI for each test
}
var jwtToken string
var err error
// Use appropriate key type and create corresponding JWK
if tc.keyType == "RSA" {
// Build an RSA JWK whose alg matches the algorithm under test
rsaJWK := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: tc.algorithm, // Use the algorithm being tested
N: base64.RawURLEncoding.EncodeToString(ts.rsaPrivateKey.PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
}
// Serve only this JWK from the mock cache
ts.mockJWKCache.JWKS = &JWKSet{
Keys: []JWK{rsaJWK},
}
jwtToken, err = createTestJWT(ts.rsaPrivateKey, tc.algorithm, "test-key-id", standardClaims)
if err != nil {
if !tc.shouldSucceed {
t.Logf("Expected failure creating JWT with %s algorithm: %v", tc.algorithm, err)
return // This is expected for unsupported algorithms
}
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
}
} else if tc.keyType == "EC" {
// Generate EC key for the specific curve
var curve elliptic.Curve
switch tc.algorithm {
case "ES256":
curve = elliptic.P256()
case "ES384":
curve = elliptic.P384()
case "ES512":
curve = elliptic.P521()
default:
t.Fatalf("Unsupported EC algorithm: %s", tc.algorithm)
}
ecPrivateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatalf("Failed to generate EC key for %s: %v", tc.algorithm, err)
}
// Create EC JWK for this test
ecJWK := createECJWK(ecPrivateKey, tc.algorithm, "test-ec-key-id")
// Replace the JWK cache entirely with just the EC key for this test
ts.mockJWKCache.JWKS = &JWKSet{
Keys: []JWK{ecJWK},
}
// Ensure rate limiter is initialized for EC tests
if ts.tOidc.limiter == nil {
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
}
jwtToken, err = createTestJWTWithECKey(ecPrivateKey, tc.algorithm, "test-ec-key-id", standardClaims)
if err != nil {
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
}
} else {
t.Fatalf("Unsupported key type: %s", tc.keyType)
}
// Verify the token
err = ts.tOidc.VerifyToken(jwtToken)
if tc.shouldSucceed {
if err != nil {
t.Errorf("Verification with %s failed: %v", tc.algorithm, err)
} else {
t.Logf("Successfully verified token with %s algorithm", tc.algorithm)
}
} else {
if err == nil {
t.Errorf("Verification with unsupported algorithm %s succeeded", tc.algorithm)
} else {
// Check that the error message indicates unsupported algorithm
if !strings.Contains(err.Error(), "unsupported algorithm") {
t.Errorf("Expected unsupported algorithm error for %s, but got: %v", tc.algorithm, err)
} else {
t.Logf("Correctly rejected unsupported algorithm %s: %v", tc.algorithm, err)
}
}
}
})
}
}
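// TestDifferentSigningAlgorithms pairs each ES* algorithm with a curve, and
// createTestJWTWithECKey pairs it with a hash and a fixed r/s width. RFC 7518
// section 3.4 defines ES256 as P-256 with SHA-256, ES384 as P-384 with SHA-384
// and ES512 as P-521 with SHA-512. The hypothetical helper below sketches that
// mapping in one place; it is illustration only and is not called by the tests.
//
// sketchECParamsForAlg is hypothetical: it returns the curve, hash and the
// byte width of each of r and s for an ES* algorithm.
func sketchECParamsForAlg(alg string) (elliptic.Curve, crypto.Hash, int, error) {
	var curve elliptic.Curve
	var hash crypto.Hash
	switch alg {
	case "ES256":
		curve, hash = elliptic.P256(), crypto.SHA256
	case "ES384":
		curve, hash = elliptic.P384(), crypto.SHA384
	case "ES512":
		curve, hash = elliptic.P521(), crypto.SHA512
	default:
		return nil, 0, 0, fmt.Errorf("unsupported EC algorithm: %s", alg)
	}
	// Width of r and s in the JOSE signature: the curve bit size rounded up
	// to whole bytes (32, 48 or 66).
	size := (curve.Params().BitSize + 7) / 8
	return curve, hash, size, nil
}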
// createTestJWTWithECKey creates a JWT signed with an EC private key
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
// Create the header
header := map[string]interface{}{
"alg": alg,
"typ": "JWT",
"kid": kid,
}
// Encode header and claims as unpadded base64url (RawURLEncoding), as JWS requires
headerJSON, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal header: %v", err)
}
headerBase64 := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal claims: %v", err)
}
claimsBase64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
// Create the signing input
signingInput := headerBase64 + "." + claimsBase64
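// Sign per RFC 7518 section 3.4: hash the signing input with the algorithm's
// hash, then concatenate r and s as fixed-width big-endian values (R || S)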
var signature []byte
switch alg {
case "ES256":
h := crypto.SHA256.New()
h.Write([]byte(signingInput))
hashed := h.Sum(nil)
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
if err != nil {
return "", fmt.Errorf("failed to sign with ES256: %v", err)
}
// ES256 (P-256): r and s are each encoded as exactly 32 bytes (256 bits / 8)
rBytes := r.Bytes()
sBytes := s.Bytes()
if len(rBytes) < 32 {
padded := make([]byte, 32)
copy(padded[32-len(rBytes):], rBytes)
rBytes = padded
}
if len(sBytes) < 32 {
padded := make([]byte, 32)
copy(padded[32-len(sBytes):], sBytes)
sBytes = padded
}
signature = append(rBytes, sBytes...)
case "ES384":
h := crypto.SHA384.New()
h.Write([]byte(signingInput))
hashed := h.Sum(nil)
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
if err != nil {
return "", fmt.Errorf("failed to sign with ES384: %v", err)
}
// ES384 (P-384): r and s are each encoded as exactly 48 bytes (384 bits / 8)
rBytes := r.Bytes()
sBytes := s.Bytes()
// Pad to exactly 48 bytes each
if len(rBytes) < 48 {
padded := make([]byte, 48)
copy(padded[48-len(rBytes):], rBytes)
rBytes = padded
} else if len(rBytes) > 48 {
// Truncate if too long (shouldn't happen with P-384)
rBytes = rBytes[len(rBytes)-48:]
}
if len(sBytes) < 48 {
padded := make([]byte, 48)
copy(padded[48-len(sBytes):], sBytes)
sBytes = padded
} else if len(sBytes) > 48 {
// Truncate if too long (shouldn't happen with P-384)
sBytes = sBytes[len(sBytes)-48:]
}
signature = append(rBytes, sBytes...)
case "ES512":
h := crypto.SHA512.New()
h.Write([]byte(signingInput))
hashed := h.Sum(nil)
r, s, err := ecdsa.Sign(rand.Reader, privateKey, hashed)
if err != nil {
return "", fmt.Errorf("failed to sign with ES512: %v", err)
}
// ES512 (P-521): r and s are each encoded as exactly 66 bytes (521 bits / 8 = 65.125, rounded up)
rBytes := r.Bytes()
sBytes := s.Bytes()
// Pad to 66 bytes each
if len(rBytes) < 66 {
padded := make([]byte, 66)
copy(padded[66-len(rBytes):], rBytes)
rBytes = padded
} else if len(rBytes) > 66 {
// Truncate if too long (shouldn't happen with P-521)
rBytes = rBytes[len(rBytes)-66:]
}
if len(sBytes) < 66 {
padded := make([]byte, 66)
copy(padded[66-len(sBytes):], sBytes)
sBytes = padded
} else if len(sBytes) > 66 {
// Truncate if too long (shouldn't happen with P-521)
sBytes = sBytes[len(sBytes)-66:]
}
signature = append(rBytes, sBytes...)
default:
return "", fmt.Errorf("unsupported EC algorithm: %s", alg)
}
// Encode signature
signatureBase64 := base64.RawURLEncoding.EncodeToString(signature)
// Combine to create JWT
return signingInput + "." + signatureBase64, nil
}
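// createTestJWTWithECKey pads r and s by hand in three near-identical branches.
// As a hedged alternative sketch (hypothetical helper, not used by the tests
// above), big.Int.FillBytes (Go 1.15+) writes a value left-padded with zeros
// into a fixed-width slice, which yields the R || S encoding RFC 7518
// section 3.4 requires for every ES* algorithm. The caller supplies the
// digest computed with the algorithm's hash.
//
// sketchSignJOSEES is hypothetical: it signs digest and returns R || S.
func sketchSignJOSEES(priv *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
	r, s, err := ecdsa.Sign(rand.Reader, priv, digest)
	if err != nil {
		return nil, fmt.Errorf("ecdsa sign failed: %v", err)
	}
	// Width of each half: the curve size in whole bytes (32, 48 or 66).
	size := (priv.Curve.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	// FillBytes panics if a value does not fit; r and s are always smaller
	// than the curve order, so each fits in size bytes.
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig, nil
}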
// createECJWK builds the public JWK (kty, crv, x, y) for an EC private key;
// the private scalar D is not included
func createECJWK(privateKey *ecdsa.PrivateKey, alg, kid string) JWK {
// Get the curve name
var crv string
switch privateKey.Curve {
case elliptic.P256():
crv = "P-256"
case elliptic.P384():
crv = "P-384"
case elliptic.P521():
crv = "P-521"
default:
panic("unsupported curve")
}
// Coordinate length in bytes: the curve bit size rounded up (32, 48 or 66)
keySize := (privateKey.Curve.Params().BitSize + 7) / 8
// Big-endian bytes of the public point's X and Y coordinates
xBytes := privateKey.PublicKey.X.Bytes()
yBytes := privateKey.PublicKey.Y.Bytes()
// Left-pad with zeros: RFC 7518 requires x and y to be the full coordinate length
if len(xBytes) < keySize {
padded := make([]byte, keySize)
copy(padded[keySize-len(xBytes):], xBytes)
xBytes = padded
}
if len(yBytes) < keySize {
padded := make([]byte, keySize)
copy(padded[keySize-len(yBytes):], yBytes)
yBytes = padded
}
return JWK{
Kty: "EC",
Kid: kid,
Alg: alg,
Crv: crv,
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
}
}
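// createECJWK maps the curve to its "crv" name with a switch and pads X and Y
// by hand. A shorter hypothetical sketch: Go's NIST curves already report
// "P-256", "P-384" and "P-521" as Params().Name, matching the crv values in
// RFC 7518 section 6.2.1.1, and FillBytes produces the full-length big-endian
// coordinates that sections 6.2.1.2 and 6.2.1.3 require.
//
// sketchECJWKFields is hypothetical: it returns crv, x and y for a public key.
func sketchECJWKFields(pub *ecdsa.PublicKey) (crv, x, y string) {
	params := pub.Curve.Params()
	size := (params.BitSize + 7) / 8
	xBytes := make([]byte, size)
	yBytes := make([]byte, size)
	pub.X.FillBytes(xBytes) // left-pads with zeros to size bytes
	pub.Y.FillBytes(yBytes)
	return params.Name,
		base64.RawURLEncoding.EncodeToString(xBytes),
		base64.RawURLEncoding.EncodeToString(yBytes)
}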
// TestMalformedTokens tests the plugin's handling of malformed tokens
func TestMalformedTokens(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
testCases := []struct {
name string
token string
expectedError string
}{
{
name: "Empty Token",
token: "",
expectedError: "invalid JWT format",
},
{
name: "Missing Parts",
token: "header.payload",
expectedError: "invalid JWT format",
},
{
name: "Invalid Base64 in Header",
token: "invalid!base64.payload.signature",
expectedError: "failed to decode header",
},
{
name: "Invalid Base64 in Payload",
token: "eyJhbGciOiJSUzI1NiJ9.invalid!base64.signature",
expectedError: "failed to decode claims",
},
{
name: "Invalid Base64 in Signature",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid!base64",
expectedError: "failed to decode signature",
},
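// The next two cases end in "=" padding, which unpadded base64url
// (RawURLEncoding) rejects, so with such a decoder they fail at the
// base64 step before any JSON is parsed.
{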
name: "Invalid JSON in Header",
token: "eyJpbnZhbGlkIGpzb24=.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature",
expectedError: "failed to decode header",
},
{
name: "Invalid JSON in Payload",
token: "eyJhbGciOiJSUzI1NiJ9.eyJpbnZhbGlkIGpzb24=.signature",
expectedError: "failed to decode claims",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ts.tOidc.VerifyToken(tc.token)
// Should fail with expected error
if err == nil {
t.Errorf("Malformed token was incorrectly verified: %s", tc.token)
} else {
if !strings.Contains(err.Error(), tc.expectedError) {
t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
}
}
})
}
}
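// TestMalformedTokens only asserts on error substrings. The hypothetical
// helper below sketches the structural checks those cases exercise: exactly
// three dot-separated segments, each decoded as unpadded base64url (RFC 7515
// section 2), with header and claims parsed as JSON objects. Its error texts
// mirror the substrings asserted above; the plugin's own parser may order or
// word its checks differently.
//
// sketchSplitJWT is hypothetical: it returns the decoded header, claims and
// raw signature bytes without verifying anything.
func sketchSplitJWT(token string) (map[string]interface{}, map[string]interface{}, []byte, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, nil, nil, fmt.Errorf("invalid JWT format: expected 3 segments, got %d", len(parts))
	}
	// decodeObject base64url-decodes a segment and parses it as a JSON object.
	decodeObject := func(segment string) (map[string]interface{}, error) {
		raw, err := base64.RawURLEncoding.DecodeString(segment)
		if err != nil {
			return nil, err
		}
		var obj map[string]interface{}
		if err = json.Unmarshal(raw, &obj); err != nil {
			return nil, err
		}
		if obj == nil { // e.g. the literal JSON value null
			return nil, fmt.Errorf("segment is not a JSON object")
		}
		return obj, nil
	}
	header, err := decodeObject(parts[0])
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to decode header: %v", err)
	}
	claims, err := decodeObject(parts[1])
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to decode claims: %v", err)
	}
	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, nil, nil, fmt.Errorf("failed to decode signature: %v", err)
	}
	return header, claims, signature, nil
}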
// TestRateLimiting tests the rate limiting functionality to prevent brute force attacks
func TestRateLimiting(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Create a fresh instance for this test to avoid affecting other tests with rate limiting
logger := NewLogger("debug")
// Create a new test suite for this test only
ts := NewTestSuite(t)
ts.Setup()
// Create a separate TraefikOidc instance with a very restrictive rate limiter
// so that the shared test-suite instance (ts.tOidc) is never throttled
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
// Burst of 2 tokens, refilled at one token per 10 seconds: the first two
// calls pass and an immediate third call is rejected
limiter: rate.NewLimiter(rate.Every(10*time.Second), 2),
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
}
// Set up the token verifier and JWT verifier
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Create a valid JWT token
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create valid JWT: %v", err)
}
// First request should succeed
err = tOidc.VerifyToken(validJWT)
if err != nil {
t.Fatalf("First token verification failed unexpectedly: %v", err)
}
// Second request should succeed
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create second valid JWT: %v", err)
}
err = tOidc.VerifyToken(validJWT2)
if err != nil {
t.Fatalf("Second token verification failed unexpectedly: %v", err)
}
// Third request should be rate limited
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create third valid JWT: %v", err)
}
err = tOidc.VerifyToken(validJWT3)
if err == nil {
t.Errorf("Third token verification succeeded despite rate limiting")
} else {
// Check that the error message indicates rate limiting
if !strings.Contains(strings.ToLower(err.Error()), "rate") {
t.Errorf("Expected rate limiting error, but got: %v", err)
}
}
}
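// The limiter in TestRateLimiting, rate.NewLimiter(rate.Every(10*time.Second), 2),
// behaves like a token bucket that starts full with 2 tokens and regains one
// token every 10 seconds, which is why the test expects the third immediate
// verification to fail. The hypothetical type below is a simplified,
// stdlib-only token bucket with an injectable clock that models the same
// behavior; it is not the golang.org/x/time/rate implementation or API.
//
// sketchTokenBucket is hypothetical and not safe for concurrent use.
type sketchTokenBucket struct {
	capacity float64       // maximum tokens (the limiter's burst)
	interval time.Duration // time needed to regain one token
	tokens   float64       // tokens currently available
	last     time.Time     // time of the last refill
}

// newSketchTokenBucket returns a full bucket whose clock starts at now.
func newSketchTokenBucket(capacity int, interval time.Duration, now time.Time) *sketchTokenBucket {
	return &sketchTokenBucket{
		capacity: float64(capacity),
		interval: interval,
		tokens:   float64(capacity),
		last:     now,
	}
}

// Allow refills the bucket for the time elapsed since the last refill, capped
// at capacity, then spends one token if at least one is available.
func (b *sketchTokenBucket) Allow(now time.Time) bool {
	if elapsed := now.Sub(b.last); elapsed > 0 {
		b.tokens += float64(elapsed) / float64(b.interval)
		if b.tokens > b.capacity {
			b.tokens = b.capacity
		}
		b.last = now
	}
	if b.tokens >= 1 {
		b.tokens--
		return true
	}
	return false
}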
// TestAuthorizationHeaderBypass tests that the plugin correctly handles attempts to bypass
// authorization by directly providing an Authorization header
func TestAuthorizationHeaderBypass(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
ts := NewTestSuite(t)
ts.Setup()
// Create a test next handler that would indicate successful authentication
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Authenticated"))
})
// Create the TraefikOidc instance
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: NewLogger("debug"),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
}
close(tOidc.initComplete)
// Create a request with a forged Authorization header but no valid session
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
// Add a forged Authorization header
req.Header.Set("Authorization", "Bearer "+ts.token)
// Record the response
resp := httptest.NewRecorder()
// Process the request
tOidc.ServeHTTP(resp, req)
// The middleware should not honor the direct Authorization header
// and should either redirect to authentication or return an error
if resp.Code == http.StatusOK {
body := resp.Body.String()
if body == "Authenticated" {
t.Errorf("Authorization header bypass succeeded - request was authenticated without a valid session")
}
}
// Verify that the response is a redirect to authentication (302) or unauthorized (401)
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized}
codeFound := false
for _, code := range expectedCodes {
if resp.Code == code {
codeFound = true
break
}
}
if !codeFound {
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, resp.Code)
}
}
// TestEmptyAudience verifies that a token with an empty audience claim is rejected
func TestEmptyAudience(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Create a JWT with empty audience
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "", // Empty audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create JWT with empty audience: %v", err)
}
// Verify the token
err = ts.tOidc.VerifyToken(emptyAudJWT)
// Should fail due to invalid audience
if err == nil {
t.Errorf("Token with empty audience was incorrectly verified")
} else {
// Check error message
if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected invalid audience error, but got: %v", err)
}
}
}
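// RFC 7519 section 4.1.3 allows "aud" to be either a single string or an
// array of strings, and a token must be rejected if the recipient does not
// find itself in it, so an empty value like the one above can never match.
// The hypothetical helper below is a minimal sketch of such a check over
// JSON-decoded claims (a string arrives as string, an array as
// []interface{}); it is not the plugin's implementation.
//
// sketchAudienceMatches is hypothetical: it reports whether clientID appears in aud.
func sketchAudienceMatches(aud interface{}, clientID string) bool {
	if clientID == "" {
		return false
	}
	switch v := aud.(type) {
	case string:
		return v == clientID
	case []interface{}:
		for _, item := range v {
			if s, ok := item.(string); ok && s == clientID {
				return true
			}
		}
	}
	// Missing claim, empty array or unexpected type: no match.
	return false
}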
// TestEmptyIssuer verifies that a token with an empty issuer claim is rejected
func TestEmptyIssuer(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Create a JWT with empty issuer
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "", // Empty issuer
"aud": "test-client-id",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create JWT with empty issuer: %v", err)
}
// Verify the token
err = ts.tOidc.VerifyToken(emptyIssJWT)
// Should fail due to invalid issuer
if err == nil {
t.Errorf("Token with empty issuer was incorrectly verified")
} else {
// Check error message
if !strings.Contains(err.Error(), "invalid issuer") {
t.Errorf("Expected invalid issuer error, but got: %v", err)
}
}
}
// TestInvalidRedirectURI tests the plugin's handling of invalid redirect URIs
func TestInvalidRedirectURI(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
ts := NewTestSuite(t)
ts.Setup()
// Build a callback request carrying an attacker-supplied redirect_uri parameter
req := httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
// Create a session with a state
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set legitimate state and redirect
session.mainSession.Values["state"] = "validstate"
session.mainSession.Values["redirect"] = "/legitimate-redirect"
resp := httptest.NewRecorder()
if err := session.Save(req, resp); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Capture the session cookies written by Save
cookies := resp.Result().Cookies()
// Create a new request with those cookies
req = httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
// Add cookies
for _, cookie := range cookies {
req.AddCookie(cookie)
}
// Create a next handler for the middleware
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create the TraefikOidc instance
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: NewLogger("debug"),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
tokenExchanger: ts.tOidc.tokenExchanger,
}
close(tOidc.initComplete)
// Process the callback request
resp = httptest.NewRecorder()
tOidc.ServeHTTP(resp, req)
// Check if open redirect is blocked
// The response should not redirect to the evil.com domain
location := resp.Header().Get("Location")
if location != "" && strings.Contains(location, "evil.com") {
t.Errorf("Open redirect vulnerability - redirected to %s", location)
}
// If the handler redirects at all, it must use the redirect stored in the session
if location != "" && !strings.Contains(location, "/legitimate-redirect") {
t.Errorf("Expected redirect to /legitimate-redirect, but got: %s", location)
}
}
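// TestInvalidRedirectURI expects the callback to ignore the attacker-supplied
// redirect_uri and honor only the redirect stored in the session. As a hedged
// sketch of an extra defensive check (hypothetical helper; the plugin may rely
// solely on the session value), a stored redirect can be limited to a local
// path: it must start with a single "/". A leading "//" or "/\" is rejected
// because browsers treat backslashes like slashes in http(s) URLs, turning the
// value into a scheme-relative URL for another host, and control characters
// are rejected because browsers strip ASCII tab and newline while parsing
// (WHATWG URL Standard), so "/\t/evil.com" would become "//evil.com".
//
// sketchIsSafeLocalRedirect is hypothetical: it reports whether target is a
// same-origin path suitable for a post-login redirect.
func sketchIsSafeLocalRedirect(target string) bool {
	if target == "" || target[0] != '/' {
		return false // empty, relative or absolute URL
	}
	if len(target) > 1 && (target[1] == '/' || target[1] == '\\') {
		return false // scheme-relative, e.g. //evil.com or /\evil.com
	}
	for i := 0; i < len(target); i++ {
		if target[i] < 0x20 || target[i] == 0x7f {
			return false // control characters a browser might strip
		}
	}
	return true
}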