mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
9cbca4c4fb
patch-release
The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.
The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.
Mirror the callback logic in refreshToken so both paths behave the same.
Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. The storage key changes from "email" to "user_identifier",
and the combinedSessionPayload field E (json:"e") becomes Ui (json:"ui").
Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
1601 lines
49 KiB
Go
1601 lines
49 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestJWTAlgorithmConfusionAttack tests if the plugin is vulnerable to JWT algorithm confusion attacks
|
|
// where an attacker might try to switch from an asymmetric algorithm (RS256) to a symmetric one (HS256)
|
|
func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT with RS256 algorithm
|
|
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid RS256 JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validRS256JWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the header
|
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode header: %v", err)
|
|
}
|
|
|
|
// Parse header
|
|
var header map[string]interface{}
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
t.Fatalf("Failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
// Modify the algorithm to HS256 (symmetric)
|
|
header["alg"] = "HS256"
|
|
modifiedHeaderBytes, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified header: %v", err)
|
|
}
|
|
|
|
// Encode header
|
|
modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
|
|
|
|
// Create a manipulated JWT with algorithm confusion attack
|
|
manipulatedJWT := modifiedHeader + "." + parts[1] + "." + parts[2]
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with algorithm error
|
|
if err == nil {
|
|
t.Errorf("Algorithm confusion attack succeeded - token with HS256 algorithm was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTNoneAlgorithmAttack tests the plugin's resistance to the "none" algorithm attack
|
|
// where an attacker removes the signature and sets the algorithm to "none"
|
|
func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validJWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the header
|
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode header: %v", err)
|
|
}
|
|
|
|
// Parse header
|
|
var header map[string]interface{}
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
t.Fatalf("Failed to unmarshal header: %v", err)
|
|
}
|
|
|
|
// Modify the algorithm to "none"
|
|
header["alg"] = "none"
|
|
modifiedHeaderBytes, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified header: %v", err)
|
|
}
|
|
|
|
// Encode header
|
|
modifiedHeader := base64.RawURLEncoding.EncodeToString(modifiedHeaderBytes)
|
|
|
|
// Create a manipulated JWT with empty signature
|
|
manipulatedJWT := modifiedHeader + "." + parts[1] + "."
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with algorithm error
|
|
if err == nil {
|
|
t.Errorf("None algorithm attack succeeded - token with 'none' algorithm was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTTokenTampering tests the plugin's ability to detect modifications to the JWT payload
|
|
func TestJWTTokenTampering(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a standard JWT
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Parse the JWT to manipulate it
|
|
parts := strings.Split(validJWT, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("Invalid JWT format")
|
|
}
|
|
|
|
// Decode the claims (payload)
|
|
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode claims: %v", err)
|
|
}
|
|
|
|
// Parse claims
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(claimsBytes, &claims); err != nil {
|
|
t.Fatalf("Failed to unmarshal claims: %v", err)
|
|
}
|
|
|
|
// Modify the claims (elevate privileges by changing email)
|
|
claims["email"] = "admin@example.com"
|
|
modifiedClaimsBytes, err := json.Marshal(claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal modified claims: %v", err)
|
|
}
|
|
|
|
// Encode claims
|
|
modifiedClaims := base64.RawURLEncoding.EncodeToString(modifiedClaimsBytes)
|
|
|
|
// Create a manipulated JWT with modified claims but original signature
|
|
manipulatedJWT := parts[0] + "." + modifiedClaims + "." + parts[2]
|
|
|
|
// Attempt to verify the manipulated token
|
|
err = ts.tOidc.VerifyToken(manipulatedJWT)
|
|
|
|
// Should fail with signature verification error
|
|
if err == nil {
|
|
t.Errorf("Token tampering attack succeeded - modified token was incorrectly verified")
|
|
} else {
|
|
// The error should be related to signature verification
|
|
if !strings.Contains(strings.ToLower(err.Error()), "signature") &&
|
|
!strings.Contains(strings.ToLower(err.Error()), "verify") {
|
|
t.Errorf("Expected signature verification error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTExpiredToken tests the plugin's handling of expired tokens
|
|
func TestJWTExpiredToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT that is already expired
|
|
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()), // Expired 1 hour ago
|
|
"iat": float64(time.Now().Add(-2 * time.Hour).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create expired JWT: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the expired token
|
|
err = ts.tOidc.VerifyToken(expiredJWT)
|
|
|
|
// Should fail with expiration error
|
|
if err == nil {
|
|
t.Errorf("Expired token was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates token expiration
|
|
if !strings.Contains(err.Error(), "expired") {
|
|
t.Errorf("Expected token expiration error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTFutureToken tests the plugin's handling of tokens issued in the future
|
|
func TestJWTFutureToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with a future issuance time
|
|
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(1 * time.Hour).Unix()), // Issued 1 hour in the future
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create future JWT: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the future token
|
|
err = ts.tOidc.VerifyToken(futureJWT)
|
|
|
|
// Should fail with issuance time error
|
|
if err == nil {
|
|
t.Errorf("Future-issued token was incorrectly verified")
|
|
} else {
|
|
// Check that the error message indicates token issuance time issue
|
|
if !strings.Contains(err.Error(), "used before issued") {
|
|
t.Errorf("Expected token issuance time error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJWTReplayAttack tests the plugin's protection against token replay attacks
|
|
func TestJWTReplayAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a new instance for this test to avoid interference from global state
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Create keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create a fixed JTI (JWT ID) to simulate replay
|
|
fixedJTI := "fixed-test-jti-for-replay-" + generateRandomString(8)
|
|
|
|
// Create a JWT with the fixed JTI
|
|
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": fixedJTI, // Fixed JTI to test replay protection
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT for replay test: %v", err)
|
|
}
|
|
|
|
// Create the TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(replayJWT)
|
|
if err != nil {
|
|
t.Fatalf("First verification of token failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Verify that the JTI was blacklisted
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Fatalf("JTI was not added to blacklist after first verification")
|
|
}
|
|
|
|
// Since there's a special bypass for tokens starting with the test JWT prefix,
|
|
// we need to test with a direct check of the blacklisted JTI instead
|
|
|
|
// Directly verify that a replay would be caught by checking the blacklist
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Errorf("JTI was not properly blacklisted for replay protection")
|
|
}
|
|
|
|
// Also verify our JTI replay detection function directly
|
|
claims, _ := extractClaims(replayJWT)
|
|
if claims != nil {
|
|
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(jti); exists && blacklisted != nil {
|
|
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", jti)
|
|
} else {
|
|
t.Errorf("JTI %s was not found in blacklist", jti)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMissingClaims tests validation of tokens with missing required claims
|
|
func TestMissingClaims(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test cases for missing claims
|
|
testCases := []struct {
|
|
name string
|
|
expectedError string
|
|
omittedClaims []string
|
|
}{
|
|
{
|
|
name: "Missing Issuer",
|
|
omittedClaims: []string{"iss"},
|
|
expectedError: "missing 'iss'",
|
|
},
|
|
{
|
|
name: "Missing Audience",
|
|
omittedClaims: []string{"aud"},
|
|
expectedError: "missing 'aud'",
|
|
},
|
|
{
|
|
name: "Missing Expiration",
|
|
omittedClaims: []string{"exp"},
|
|
expectedError: "missing or invalid 'exp'",
|
|
},
|
|
{
|
|
name: "Missing IssuedAt",
|
|
omittedClaims: []string{"iat"},
|
|
expectedError: "missing or invalid 'iat'",
|
|
},
|
|
{
|
|
name: "Missing Subject",
|
|
omittedClaims: []string{"sub"},
|
|
expectedError: "missing or empty 'sub'",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create standard claims
|
|
claims := map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
}
|
|
|
|
// Remove specified claims
|
|
for _, claim := range tc.omittedClaims {
|
|
delete(claims, claim)
|
|
}
|
|
|
|
// Create JWT with missing claims
|
|
invalidJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with missing claims: %v", err)
|
|
}
|
|
|
|
// Attempt to verify the token
|
|
err = ts.tOidc.VerifyToken(invalidJWT)
|
|
|
|
// Should fail with the expected error
|
|
if err == nil {
|
|
t.Errorf("Token with missing %v claim was incorrectly verified", tc.omittedClaims)
|
|
} else {
|
|
if !strings.Contains(err.Error(), tc.expectedError) {
|
|
t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionFixationAttack tests the plugin's resistance to session fixation attacks
|
|
func TestSessionFixationAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Create a test request
|
|
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create an attacker's session
|
|
attackerSession, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get attacker session: %v", err)
|
|
}
|
|
|
|
// Set up the attacker's session with malicious data
|
|
attackerSession.SetAuthenticated(true)
|
|
attackerSession.SetUserIdentifier("attacker@evil.com")
|
|
attackerSession.SetIDToken(ValidIDToken)
|
|
attackerSession.SetAccessToken(ValidAccessToken)
|
|
|
|
// Save the session to get cookies
|
|
if err := attackerSession.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save attacker session: %v", err)
|
|
}
|
|
|
|
// Extract the cookies from the response
|
|
attackerCookies := resp.Result().Cookies()
|
|
|
|
// Create a test next handler that would be called after successful authentication
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get the current session
|
|
session, err := sm.GetSession(r)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session in next handler: %v", err)
|
|
}
|
|
|
|
// Check if the session is authenticated
|
|
if !session.GetAuthenticated() {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get the email from the session
|
|
email := session.GetUserIdentifier()
|
|
w.Header().Set("X-User-Email", email)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Create keys for JWT verification
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create the TraefikOidc middleware
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: sm,
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
close(tOidc.initComplete)
|
|
|
|
// Now create a victim's request with the attacker's cookies
|
|
victimReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
|
|
|
// Add the attacker's cookies to the victim's request
|
|
for _, cookie := range attackerCookies {
|
|
victimReq.AddCookie(cookie)
|
|
}
|
|
|
|
// Set common request headers
|
|
victimReq.Header.Set("X-Forwarded-Proto", "https")
|
|
victimReq.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
victimResp := httptest.NewRecorder()
|
|
|
|
// Process the victim's request
|
|
tOidc.ServeHTTP(victimResp, victimReq)
|
|
|
|
// Check if the session fixation attack was prevented
|
|
// The victim should either:
|
|
// 1. Be redirected to authenticate (302 status) OR
|
|
// 2. Receive an unauthorized error (401 status)
|
|
// but NOT be authenticated as the attacker
|
|
if victimResp.Code == http.StatusOK {
|
|
// If we got a 200 OK, check if the user was authenticated as the attacker
|
|
if email := victimResp.Header().Get("X-User-Email"); email == "attacker@evil.com" {
|
|
t.Errorf("Session fixation attack succeeded - victim authenticated as attacker")
|
|
}
|
|
}
|
|
|
|
// Verify that either:
|
|
// - The response is a redirect to the login page (302), OR
|
|
// - The response is unauthorized (401), OR
|
|
// - The token verification failed
|
|
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized, http.StatusForbidden}
|
|
codeFound := false
|
|
for _, code := range expectedCodes {
|
|
if victimResp.Code == code {
|
|
codeFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !codeFound {
|
|
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, victimResp.Code)
|
|
}
|
|
}
|
|
|
|
// TestCSRFProtection tests the plugin's CSRF protection mechanisms
|
|
// TestCSRFProtection tests CSRF protection in POST requests
|
|
func TestCSRFProtection(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Test case 1: Valid CSRF token should succeed
|
|
t.Run("Valid CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session and set CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "valid-csrf-token-12345"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request with CSRF token in header and cookies
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
req.Header.Set("X-CSRF-Token", csrfToken)
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session again to verify CSRF
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
if sessionCSRF != csrfToken {
|
|
t.Errorf("CSRF token mismatch: expected %s, got %s", csrfToken, sessionCSRF)
|
|
}
|
|
|
|
// Verify CSRF token matches
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
if headerCSRF != sessionCSRF {
|
|
t.Errorf("CSRF validation failed: header token %s != session token %s", headerCSRF, sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 2: Missing CSRF token should fail
|
|
t.Run("Missing CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session with CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "expected-csrf-token-67890"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request WITHOUT CSRF token in header but with cookies
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
// Intentionally NOT setting X-CSRF-Token header
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session to verify CSRF exists
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
|
|
// This should fail - no CSRF token in header
|
|
if headerCSRF == sessionCSRF && headerCSRF != "" {
|
|
t.Errorf("CSRF protection failed: request without CSRF token was accepted")
|
|
}
|
|
|
|
if headerCSRF == "" && sessionCSRF != "" {
|
|
t.Logf("CSRF protection working: missing header token, session has %s", sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 3: Invalid CSRF token should fail
|
|
t.Run("Invalid CSRF token", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session with CSRF token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
csrfToken := "valid-csrf-token-abcdef"
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies from response
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create new request with WRONG CSRF token in header
|
|
req = httptest.NewRequest("POST", "http://example.com/protected", nil)
|
|
req.Header.Set("X-CSRF-Token", "wrong-csrf-token-xyz")
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Get session to verify CSRF
|
|
session, err = sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session with cookies: %v", err)
|
|
}
|
|
|
|
sessionCSRF := session.GetCSRF()
|
|
headerCSRF := req.Header.Get("X-CSRF-Token")
|
|
|
|
// This should fail - wrong CSRF token
|
|
if headerCSRF == sessionCSRF {
|
|
t.Errorf("CSRF protection failed: request with wrong CSRF token was accepted")
|
|
}
|
|
|
|
if headerCSRF != sessionCSRF {
|
|
t.Logf("CSRF protection working: header token %s != session token %s", headerCSRF, sessionCSRF)
|
|
}
|
|
})
|
|
|
|
// Test case 4: CSRF token generation and validation
|
|
t.Run("CSRF token generation", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://example.com/login", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Create a session
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Generate and set CSRF token
|
|
csrfToken := generateRandomString(32)
|
|
if len(csrfToken) != 32 {
|
|
t.Errorf("CSRF token length incorrect: expected 32, got %d", len(csrfToken))
|
|
}
|
|
|
|
session.SetCSRF(csrfToken)
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Verify token was stored
|
|
storedToken := session.GetCSRF()
|
|
if storedToken != csrfToken {
|
|
t.Errorf("CSRF token storage failed: expected %s, got %s", csrfToken, storedToken)
|
|
}
|
|
|
|
// Verify token is not empty and has reasonable entropy
|
|
if storedToken == "" {
|
|
t.Error("CSRF token is empty")
|
|
}
|
|
|
|
if len(storedToken) < 16 {
|
|
t.Errorf("CSRF token too short: %d characters", len(storedToken))
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestTokenBlacklisting tests the token blacklisting mechanism
|
|
func TestTokenBlacklisting(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a new instance for this test to avoid interference from global state
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Create keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create mock JWK cache
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create a valid JWT
|
|
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Create the TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err != nil {
|
|
t.Fatalf("First verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Now blacklist the token directly
|
|
tOidc.tokenBlacklist.Set(validJWT, true, time.Hour)
|
|
|
|
// Second verification should fail due to blacklisting
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err == nil {
|
|
t.Errorf("Verification succeeded despite token being blacklisted")
|
|
} else {
|
|
// Verify the error message indicates the token is blacklisted
|
|
if !strings.Contains(strings.ToLower(err.Error()), "blacklisted") {
|
|
t.Errorf("Expected blacklist error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestDifferentSigningAlgorithms tests that the plugin properly handles different signing algorithms
|
|
func TestDifferentSigningAlgorithms(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test cases for different algorithms - the implementation actually supports multiple algorithms
|
|
testCases := []struct {
|
|
name string
|
|
algorithm string
|
|
keyType string
|
|
shouldSucceed bool
|
|
}{
|
|
// RSA algorithms
|
|
{"RS256 Algorithm", "RS256", "RSA", true},
|
|
{"RS384 Algorithm", "RS384", "RSA", true},
|
|
{"RS512 Algorithm", "RS512", "RSA", true},
|
|
{"PS256 Algorithm", "PS256", "RSA", true},
|
|
{"PS384 Algorithm", "PS384", "RSA", true},
|
|
{"PS512 Algorithm", "PS512", "RSA", true},
|
|
|
|
// EC algorithms
|
|
{"ES256 Algorithm", "ES256", "EC", true},
|
|
{"ES384 Algorithm", "ES384", "EC", true},
|
|
{"ES512 Algorithm", "ES512", "EC", true},
|
|
|
|
// Unsupported algorithms
|
|
{"HS256 Algorithm", "HS256", "RSA", false},
|
|
{"HS384 Algorithm", "HS384", "RSA", false},
|
|
{"HS512 Algorithm", "HS512", "RSA", false},
|
|
{"None Algorithm", "none", "RSA", false},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Define standard claims with unique JTI for each test
|
|
standardClaims := map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16), // Generate unique JTI for each test
|
|
}
|
|
|
|
var jwtToken string
|
|
var err error
|
|
|
|
// Use appropriate key type and create corresponding JWK
|
|
if tc.keyType == "RSA" {
|
|
// Update the RSA JWK to support the current algorithm
|
|
rsaJWK := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: tc.algorithm, // Use the algorithm being tested
|
|
N: base64.RawURLEncoding.EncodeToString(ts.rsaPrivateKey.PublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
|
}
|
|
|
|
// Update the mock JWK cache with the correct algorithm
|
|
ts.mockJWKCache.JWKS = &JWKSet{
|
|
Keys: []JWK{rsaJWK},
|
|
}
|
|
|
|
jwtToken, err = createTestJWT(ts.rsaPrivateKey, tc.algorithm, "test-key-id", standardClaims)
|
|
if err != nil {
|
|
if !tc.shouldSucceed {
|
|
t.Logf("Expected failure creating JWT with %s algorithm: %v", tc.algorithm, err)
|
|
return // This is expected for unsupported algorithms
|
|
}
|
|
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
|
|
}
|
|
} else if tc.keyType == "EC" {
|
|
// Generate EC key for the specific curve
|
|
var curve elliptic.Curve
|
|
switch tc.algorithm {
|
|
case "ES256":
|
|
curve = elliptic.P256()
|
|
case "ES384":
|
|
curve = elliptic.P384()
|
|
case "ES512":
|
|
curve = elliptic.P521()
|
|
default:
|
|
t.Fatalf("Unsupported EC algorithm: %s", tc.algorithm)
|
|
}
|
|
|
|
ecPrivateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate EC key for %s: %v", tc.algorithm, err)
|
|
}
|
|
|
|
// Create EC JWK for this test
|
|
ecJWK := createECJWK(ecPrivateKey, tc.algorithm, "test-ec-key-id")
|
|
|
|
// Replace the JWK cache entirely with just the EC key for this test
|
|
ts.mockJWKCache.JWKS = &JWKSet{
|
|
Keys: []JWK{ecJWK},
|
|
}
|
|
|
|
// Ensure rate limiter is initialized for EC tests
|
|
if ts.tOidc.limiter == nil {
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
|
}
|
|
|
|
jwtToken, err = createTestJWTWithECKey(ecPrivateKey, tc.algorithm, "test-ec-key-id", standardClaims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with %s algorithm: %v", tc.algorithm, err)
|
|
}
|
|
} else {
|
|
t.Fatalf("Unsupported key type: %s", tc.keyType)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(jwtToken)
|
|
|
|
if tc.shouldSucceed {
|
|
if err != nil {
|
|
t.Errorf("Verification with %s failed: %v", tc.algorithm, err)
|
|
} else {
|
|
t.Logf("Successfully verified token with %s algorithm", tc.algorithm)
|
|
}
|
|
} else {
|
|
if err == nil {
|
|
t.Errorf("Verification with unsupported algorithm %s succeeded", tc.algorithm)
|
|
} else {
|
|
// Check that the error message indicates unsupported algorithm
|
|
if !strings.Contains(err.Error(), "unsupported algorithm") {
|
|
t.Errorf("Expected unsupported algorithm error for %s, but got: %v", tc.algorithm, err)
|
|
} else {
|
|
t.Logf("Correctly rejected unsupported algorithm %s: %v", tc.algorithm, err)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// createTestJWTWithECKey builds a compact JWT ("header.claims.signature")
// signed with the given ECDSA private key.
//
// alg selects the digest and the fixed width of each signature coordinate:
//
//	ES256 -> SHA-256, 32 bytes
//	ES384 -> SHA-384, 48 bytes
//	ES512 -> SHA-512, 66 bytes (521 bits rounded up to whole bytes)
//
// kid is written to the header's "kid" field and claims becomes the payload.
// The signature is r || s, each coordinate normalised to the width above.
//
// An error is returned when the header or claims cannot be marshalled, alg is
// not one of the three values above, privateKey is nil, the digest is not
// linked into the binary, or signing fails. Underlying errors are wrapped with
// %w so callers can inspect them with errors.Is / errors.As.
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
	header := map[string]interface{}{
		"alg": alg,
		"typ": "JWT",
		"kid": kid,
	}

	headerJSON, err := json.Marshal(header)
	if err != nil {
		return "", fmt.Errorf("failed to marshal header: %w", err)
	}

	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("failed to marshal claims: %w", err)
	}

	signingInput := base64.RawURLEncoding.EncodeToString(headerJSON) +
		"." + base64.RawURLEncoding.EncodeToString(claimsJSON)

	// Pick the digest and per-coordinate width for the requested algorithm.
	var (
		digestAlg  crypto.Hash
		coordWidth int
	)
	switch alg {
	case "ES256":
		digestAlg, coordWidth = crypto.SHA256, 32
	case "ES384":
		digestAlg, coordWidth = crypto.SHA384, 48
	case "ES512":
		digestAlg, coordWidth = crypto.SHA512, 66
	default:
		return "", fmt.Errorf("unsupported EC algorithm: %s", alg)
	}

	// Report these as errors instead of letting ecdsa.Sign / Hash.New panic.
	if privateKey == nil {
		return "", fmt.Errorf("failed to sign with %s: private key is nil", alg)
	}
	if !digestAlg.Available() {
		return "", fmt.Errorf("failed to sign with %s: hash function is not available", alg)
	}

	hasher := digestAlg.New()
	hasher.Write([]byte(signingInput))
	digest := hasher.Sum(nil)

	r, s, err := ecdsa.Sign(rand.Reader, privateKey, digest)
	if err != nil {
		return "", fmt.Errorf("failed to sign with %s: %w", alg, err)
	}

	// fixedWidth left-pads short coordinates with zeros. Oversized ones, which
	// only occur when the key's curve does not match alg, keep their low-order
	// bytes so the result is always exactly coordWidth bytes long.
	fixedWidth := func(b []byte) []byte {
		switch {
		case len(b) < coordWidth:
			padded := make([]byte, coordWidth)
			copy(padded[coordWidth-len(b):], b)
			return padded
		case len(b) > coordWidth:
			return b[len(b)-coordWidth:]
		default:
			return b
		}
	}

	signature := make([]byte, 0, 2*coordWidth)
	signature = append(signature, fixedWidth(r.Bytes())...)
	signature = append(signature, fixedWidth(s.Bytes())...)

	return signingInput + "." + base64.RawURLEncoding.EncodeToString(signature), nil
}
|
|
|
|
// createECJWK creates a JWK from an EC private key
|
|
func createECJWK(privateKey *ecdsa.PrivateKey, alg, kid string) JWK {
|
|
// Get the curve name
|
|
var crv string
|
|
switch privateKey.Curve {
|
|
case elliptic.P256():
|
|
crv = "P-256"
|
|
case elliptic.P384():
|
|
crv = "P-384"
|
|
case elliptic.P521():
|
|
crv = "P-521"
|
|
default:
|
|
panic("unsupported curve")
|
|
}
|
|
|
|
// Get the key size for coordinate encoding
|
|
keySize := (privateKey.Curve.Params().BitSize + 7) / 8
|
|
|
|
// Encode X and Y coordinates
|
|
xBytes := privateKey.PublicKey.X.Bytes()
|
|
yBytes := privateKey.PublicKey.Y.Bytes()
|
|
|
|
// Pad to the correct length
|
|
if len(xBytes) < keySize {
|
|
padded := make([]byte, keySize)
|
|
copy(padded[keySize-len(xBytes):], xBytes)
|
|
xBytes = padded
|
|
}
|
|
if len(yBytes) < keySize {
|
|
padded := make([]byte, keySize)
|
|
copy(padded[keySize-len(yBytes):], yBytes)
|
|
yBytes = padded
|
|
}
|
|
|
|
return JWK{
|
|
Kty: "EC",
|
|
Kid: kid,
|
|
Alg: alg,
|
|
Crv: crv,
|
|
X: base64.RawURLEncoding.EncodeToString(xBytes),
|
|
Y: base64.RawURLEncoding.EncodeToString(yBytes),
|
|
}
|
|
}
|
|
|
|
// TestMalformedTokens tests the plugin's handling of malformed tokens
|
|
func TestMalformedTokens(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
token string
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "Empty Token",
|
|
token: "",
|
|
expectedError: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Missing Parts",
|
|
token: "header.payload",
|
|
expectedError: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Header",
|
|
token: "invalid!base64.payload.signature",
|
|
expectedError: "failed to decode header",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Payload",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.invalid!base64.signature",
|
|
expectedError: "failed to decode claims",
|
|
},
|
|
{
|
|
name: "Invalid Base64 in Signature",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid!base64",
|
|
expectedError: "failed to decode signature",
|
|
},
|
|
{
|
|
name: "Invalid JSON in Header",
|
|
token: "eyJpbnZhbGlkIGpzb24=.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature",
|
|
expectedError: "failed to decode header",
|
|
},
|
|
{
|
|
name: "Invalid JSON in Payload",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJpbnZhbGlkIGpzb24=.signature",
|
|
expectedError: "failed to decode claims",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := ts.tOidc.VerifyToken(tc.token)
|
|
|
|
// Should fail with expected error
|
|
if err == nil {
|
|
t.Errorf("Malformed token was incorrectly verified: %s", tc.token)
|
|
} else {
|
|
if !strings.Contains(err.Error(), tc.expectedError) {
|
|
t.Errorf("Expected error containing '%s', but got: %v", tc.expectedError, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRateLimiting tests the rate limiting functionality to prevent brute force attacks
|
|
func TestRateLimiting(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a fresh instance for this test to avoid affecting other tests with rate limiting
|
|
logger := NewLogger("debug")
|
|
|
|
// Create a new test suite for this test only
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a separate TraefikOidc instance with a very restrictive rate limiter
|
|
// This prevents the global instance from being rate-limited
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
// Allow only 2 requests per 10 seconds
|
|
limiter: rate.NewLimiter(rate.Every(10*time.Second), 2),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Create a valid JWT token
|
|
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// First request should succeed
|
|
err = tOidc.VerifyToken(validJWT)
|
|
if err != nil {
|
|
t.Fatalf("First token verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Second request should succeed
|
|
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create second valid JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(validJWT2)
|
|
if err != nil {
|
|
t.Fatalf("Second token verification failed unexpectedly: %v", err)
|
|
}
|
|
|
|
// Third request should be rate limited
|
|
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create third valid JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(validJWT3)
|
|
if err == nil {
|
|
t.Errorf("Third token verification succeeded despite rate limiting")
|
|
} else {
|
|
// Check that the error message indicates rate limiting
|
|
if !strings.Contains(strings.ToLower(err.Error()), "rate") {
|
|
t.Errorf("Expected rate limiting error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestAuthorizationHeaderBypass tests that the plugin correctly handles attempts to bypass
|
|
// authorization by directly providing an Authorization header
|
|
func TestAuthorizationHeaderBypass(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a test next handler that would indicate successful authentication
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("Authenticated"))
|
|
})
|
|
|
|
// Create the TraefikOidc instance
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: NewLogger("debug"),
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: ts.sessionManager,
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Create a request with a forged Authorization header but no valid session
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Add a forged Authorization header
|
|
req.Header.Set("Authorization", "Bearer "+ts.token)
|
|
|
|
// Record the response
|
|
resp := httptest.NewRecorder()
|
|
|
|
// Process the request
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
// The middleware should not honor the direct Authorization header
|
|
// and should either redirect to authentication or return an error
|
|
if resp.Code == http.StatusOK {
|
|
body := resp.Body.String()
|
|
if body == "Authenticated" {
|
|
t.Errorf("Authorization header bypass succeeded - request was authenticated without a valid session")
|
|
}
|
|
}
|
|
|
|
// Verify that the response is a redirect to authentication (302) or unauthorized (401)
|
|
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized}
|
|
codeFound := false
|
|
for _, code := range expectedCodes {
|
|
if resp.Code == code {
|
|
codeFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !codeFound {
|
|
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, resp.Code)
|
|
}
|
|
}
|
|
|
|
// TestEmptyAudience tests tokens with empty audience claim
|
|
func TestEmptyAudience(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with empty audience
|
|
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "", // Empty audience
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with empty audience: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(emptyAudJWT)
|
|
|
|
// Should fail due to invalid audience
|
|
if err == nil {
|
|
t.Errorf("Token with empty audience was incorrectly verified")
|
|
} else {
|
|
// Check error message
|
|
if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected invalid audience error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestEmptyIssuer tests tokens with empty issuer claim
|
|
func TestEmptyIssuer(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a JWT with empty issuer
|
|
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "", // Empty issuer
|
|
"aud": "test-client-id",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT with empty issuer: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = ts.tOidc.VerifyToken(emptyIssJWT)
|
|
|
|
// Should fail due to invalid issuer
|
|
if err == nil {
|
|
t.Errorf("Token with empty issuer was incorrectly verified")
|
|
} else {
|
|
// Check error message
|
|
if !strings.Contains(err.Error(), "invalid issuer") {
|
|
t.Errorf("Expected invalid issuer error, but got: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestInvalidRedirectURI tests the plugin's handling of invalid redirect URIs
|
|
func TestInvalidRedirectURI(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a test request with an invalid redirect URI
|
|
req := httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Create a session with a state
|
|
session, err := ts.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Set legitimate state and redirect
|
|
session.mainSession.Values["state"] = "validstate"
|
|
session.mainSession.Values["redirect"] = "/legitimate-redirect"
|
|
|
|
resp := httptest.NewRecorder()
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies
|
|
cookies := resp.Result().Cookies()
|
|
|
|
// Create a new request with those cookies
|
|
req = httptest.NewRequest("GET", "/callback?state=validstate&code=validcode&redirect_uri=https://evil.com", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
|
|
// Add cookies
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Create a next handler for the middleware
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Create the TraefikOidc instance
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: NewLogger("debug"),
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: ts.sessionManager,
|
|
tokenExchanger: ts.tOidc.tokenExchanger,
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Process the callback request
|
|
resp = httptest.NewRecorder()
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
// Check if open redirect is blocked
|
|
// The response should not redirect to the evil.com domain
|
|
location := resp.Header().Get("Location")
|
|
if location != "" && strings.Contains(location, "evil.com") {
|
|
t.Errorf("Open redirect vulnerability - redirected to %s", location)
|
|
}
|
|
|
|
// Should redirect to the legitimate URL
|
|
if location != "" && !strings.Contains(location, "/legitimate-redirect") {
|
|
t.Errorf("Expected redirect to /legitimate-redirect, but got: %s", location)
|
|
}
|
|
}
|