Files
traefikoidc/main_test.go
T
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00

4825 lines
155 KiB
Go

package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
// TestSuite holds common test data and setup
type TestSuite struct {
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
sessionManager *SessionManager
// utf *UnifiedTestFramework // Removed - consolidated test framework
token string
}
// NewTestSuite creates a new test suite with automatic cleanup
func NewTestSuite(t *testing.T) *TestSuite {
ts := &TestSuite{
t: t,
// utf: NewUnifiedTestFramework(t), // Removed
}
return ts
}
// Setup initializes the test suite
func (ts *TestSuite) Setup() {
// Initialize unified test framework if not already done
// Unified test framework removed - using direct cleanup
var err error
ts.rsaPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ts.t.Fatalf("Failed to generate RSA key: %v", err)
}
ts.rsaPublicKey = &ts.rsaPrivateKey.PublicKey
// Generate EC key for EC key tests
ts.ecPrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
ts.t.Fatalf("Failed to generate EC key: %v", err)
}
// Create a JWK for the RSA public key
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
// Create a mock JWKCache
ts.mockJWKCache = &MockJWKCache{
JWKS: jwks,
Err: nil,
}
// Create a test JWT token signed with the RSA private key
// Create timestamps with proper clock skew
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
// Create WaitGroup for the OIDC instance
goroutineWG := &sync.WaitGroup{}
// Initialize caches properly
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
// Type assert to get the underlying UniversalCache
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
roleClaimName: "roles", // Set default for backward compatibility
groupClaimName: "groups", // Set default for backward compatibility
userIdentifierClaim: "email", // Set default for backward compatibility
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
httpClient: &http.Client{Timeout: 10 * time.Second},
// Explicitly set paths as New() is bypassed
redirURLPath: "/callback", // Assume default callback path for tests
logoutURLPath: "/callback/logout", // Assume default logout path for tests
tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
goroutineWG: goroutineWG,
ctx: context.Background(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
}
close(ts.tOidc.initComplete)
ts.tOidc.tokenVerifier = ts.tOidc
ts.tOidc.jwtVerifier = ts.tOidc
// Set default mock exchanger
ts.tOidc.tokenExchanger = &MockTokenExchanger{
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
// Default mock behavior for code exchange
return &TokenResponse{
IDToken: ts.token, // Use the valid token from setup
AccessToken: ts.token,
RefreshToken: "default-refresh-token",
ExpiresIn: 3600,
}, nil
},
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
// Default mock behavior for refresh (can be overridden in tests)
return nil, fmt.Errorf("default mock: refresh not expected")
},
RevokeTokenFunc: func(token, tokenType string) error {
// Default mock behavior for revoke
return nil
},
}
// OIDC instance created
// Register cleanup
ts.t.Cleanup(func() {
if ts.tOidc.tokenBlacklist != nil {
ts.tOidc.tokenBlacklist.Close()
}
if ts.tOidc.tokenCache != nil && ts.tOidc.tokenCache.cache != nil {
ts.tOidc.tokenCache.cache.Close()
}
})
}
// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface.
// MockJWKCache implements JWKCacheInterface
type MockJWKCache struct {
Err error
JWKS *JWKSet
mu sync.RWMutex
}
// Close is a no-op for the mock.
func (m *MockJWKCache) Close() {
// No operation needed for the mock.
}
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.JWKS, m.Err
}
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.Err != nil {
return nil, m.Err
}
if m.JWKS == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range m.JWKS.Keys {
k := &m.JWKS.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
// Real cleanup is for expired entries, not resetting all data
}
// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls.
type MockTokenVerifier struct {
VerifyFunc func(token string) error
}
func (m *MockTokenVerifier) VerifyToken(token string) error {
if m.VerifyFunc != nil {
return m.VerifyFunc(token)
}
return fmt.Errorf("VerifyFunc not implemented in mock")
}
// MockTokenExchanger implements TokenExchanger for testing
type MockTokenExchanger struct {
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
}
func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock")
}
func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock")
}
func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return fmt.Errorf("RevokeTokenFunc not implemented in mock")
}
// Helper function to check if a token is a test token
func isTestToken(token string) bool {
// Parse the token without verification to check if it's a test token
claims, err := extractClaims(token)
if err != nil {
return false
}
// Check if the issuer is our test issuer
if iss, ok := claims["iss"].(string); ok {
return iss == "https://test-issuer.com"
}
// Check if audience is our test client
if aud, ok := claims["aud"].(string); ok {
return aud == "test-client-id"
}
return false
}
// Helper function to create a new valid token for refresh tests using test suite
func (ts *TestSuite) createNewValidToken() string {
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix()
nbf := now.Add(-2 * time.Minute).Unix()
token, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
return token
}
// Helper function to create a JWT token
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
header := map[string]interface{}{
"alg": alg,
"kid": kid,
"typ": "JWT",
}
headerJSON, err := json.Marshal(header)
if err != nil {
return "", err
}
headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", err
}
claimsEncoded := base64.RawURLEncoding.EncodeToString(claimsJSON)
signedContent := headerEncoded + "." + claimsEncoded
// Select the appropriate hash function based on algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256":
hashFunc = crypto.SHA256
case "RS384", "PS384":
hashFunc = crypto.SHA384
case "RS512", "PS512":
hashFunc = crypto.SHA512
default:
return "", fmt.Errorf("unsupported algorithm: %s", alg)
}
hasher := hashFunc.New()
hasher.Write([]byte(signedContent))
hashed := hasher.Sum(nil)
var signatureBytes []byte
// Use appropriate signing method based on algorithm
if strings.HasPrefix(alg, "RS") {
// PKCS1v15 signing for RS* algorithms
signatureBytes, err = rsa.SignPKCS1v15(rand.Reader, privateKey, hashFunc, hashed)
} else if strings.HasPrefix(alg, "PS") {
// PSS signing for PS* algorithms
signatureBytes, err = rsa.SignPSS(rand.Reader, privateKey, hashFunc, hashed, nil)
} else {
return "", fmt.Errorf("unsupported RSA algorithm: %s", alg)
}
if err != nil {
return "", err
}
signatureEncoded := base64.RawURLEncoding.EncodeToString(signatureBytes)
token := signedContent + "." + signatureEncoded
return token, nil
}
func bigIntToBytes(i *big.Int) []byte {
return i.Bytes()
}
// TestVerifyToken tests the VerifyToken method
func TestVerifyToken(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
token string
blacklist bool
rateLimit bool
cacheToken bool
expectedError bool
}{
{
name: "Valid token",
token: ts.token,
expectedError: false,
},
{
name: "Invalid token signature",
token: ts.token + "invalid",
expectedError: true,
},
{
name: "Blacklisted token",
token: ts.token,
blacklist: true,
expectedError: true,
},
{
name: "Rate limit exceeded",
token: ts.token,
rateLimit: true,
expectedError: true,
},
{
name: "Token in cache",
token: ts.token,
cacheToken: true,
expectedError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache for each test
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
// Clear the token cache instead of creating a new one (it's a singleton)
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.tokenCache.Clear()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
// Use Set with a duration. Value 'true' is arbitrary.
ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour)
}
if tc.rateLimit {
// Exceed rate limit
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
}
if tc.cacheToken {
// Use more realistic claims for cached token
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
"iss": "https://test-issuer.com",
"sub": "test-subject",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
}, time.Minute)
// Verify the token is actually in the cache
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
t.Logf("Token found in cache with claims: %v", claims)
} else {
t.Logf("Token NOT found in cache despite cacheToken=true")
}
}
err := ts.tOidc.VerifyToken(tc.token)
if tc.expectedError && err == nil {
t.Errorf("Test %s: expected error but got nil", tc.name)
}
if !tc.expectedError && err != nil {
t.Errorf("Test %s: expected no error but got %v", tc.name, err)
}
})
}
}
// TestServeHTTP tests the ServeHTTP method
func TestServeHTTP(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
ts.tOidc.next = nextHandler
ts.tOidc.name = "test"
// Helper to create an expired token
createExpiredToken := func() string {
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
iat := time.Now().Add(-2 * time.Hour).Unix()
nbf := time.Now().Add(-2 * time.Hour).Unix()
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce-expired", // Different nonce for clarity
"jti": generateRandomString(16),
})
return expiredToken
}
tests := []struct {
sessionValues map[interface{}]interface{}
setupSession func(*SessionData)
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager)
requestHeaders map[string]string
name string
requestPath string
expectedBody string
expectedStatus int
}{
{
name: "Excluded URL",
requestPath: "/favicon.ico",
expectedStatus: http.StatusOK,
expectedBody: "OK",
},
{
name: "Unauthenticated request (no refresh token) to protected URL",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Ensure no tokens are set
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token-for-unauth-test" {
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
}
// Simulate successful refresh
newToken := ts.createNewValidToken() // Use helper from TestServeHTTP
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
}
},
expectedStatus: http.StatusOK, // Expect OK after successful refresh
expectedBody: "OK",
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid")
}
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
},
{
name: "Authenticated request to protected URL (Valid Token)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
},
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
{
name: "Authenticated request with expired token and successful refresh",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Set expired token
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token" {
return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken)
}
// Simulate successful refresh
newToken := ts.createNewValidToken()
return &TokenResponse{
IDToken: newToken, // Return new valid token
AccessToken: newToken, // Often the same as ID token in tests
RefreshToken: "new-refresh-token",
ExpiresIn: 3600,
}, nil
}
},
expectedStatus: http.StatusOK, // Expect success after refresh
expectedBody: "OK",
assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) {
// Create a new request to read the cookies set by the response recorder
reqForCookieRead := httptest.NewRequest("GET", "/protected", nil)
for _, cookie := range rr.Result().Cookies() {
reqForCookieRead.AddCookie(cookie)
}
// Get session based on response cookies
session, err := sessionManager.GetSession(reqForCookieRead)
if err != nil {
t.Fatalf("Failed to get session after request: %v", err)
}
// Assert new tokens are in the session
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
}
if session.GetRefreshToken() != "new-refresh-token" {
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
}
// Also check authenticated flag is now true
if !session.GetAuthenticated() {
t.Errorf("Expected session to be marked authenticated after successful refresh")
}
},
},
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
{
name: "Logout URL",
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
},
expectedStatus: http.StatusFound, // Expect redirect after logout
expectedBody: "",
// No specific session assertion needed for logout redirect itself
},
{
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
}
},
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
expectedBody: `{"error":"Unauthorized","error_description":"Token refresh failed","status_code":401}`,
},
// This test case remains valid as the logic should still redirect browser clients on refresh failure
{
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
}
},
requestHeaders: map[string]string{
"Accept": "text/html", // Browser client
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
},
// This test case remains valid as proactive refresh should still be attempted
{
name: "Authenticated request with token nearing expiry (needs refresh)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Create token expiring soon (e.g., 30s, within default 60s grace period)
exp := time.Now().Add(30 * time.Second).Unix()
iat := time.Now().Add(-1 * time.Minute).Unix()
nbf := time.Now().Add(-1 * time.Minute).Unix()
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token-for-near-expiry" {
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
}
// Simulate successful refresh
newToken := ts.createNewValidToken()
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
}
},
expectedStatus: http.StatusOK, // Expect success after proactive refresh
expectedBody: "OK",
},
// This test case remains valid as no refresh should be attempted
{
name: "Authenticated request with token valid (outside grace period)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
exp := time.Now().Add(10 * time.Minute).Unix()
iat := time.Now().Add(-1 * time.Minute).Unix()
nbf := time.Now().Add(-1 * time.Minute).Unix()
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
// This should NOT be called
return func(refreshToken string) (*TokenResponse, error) {
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
return nil, fmt.Errorf("refresh should not have been attempted")
}
},
expectedStatus: http.StatusOK, // Expect success, no refresh needed
expectedBody: "OK",
},
{
name: "Disallowed Domain (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusForbidden,
expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`,
},
{
name: "Disallowed Domain (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
"Accept": "text/html",
},
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
},
}
// Configure allowed domains for domain restriction tests
// This allows example.com but not disallowed.com
ts.tOidc.allowedUserDomains = map[string]struct{}{
"example.com": {},
}
// Use mock JWK cache to enable proper token verification
ts.tOidc.jwkCache = ts.mockJWKCache
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache for each test to prevent token replay detection errors
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
ts.tOidc.tokenCache = NewTokenCache()
// Reset the global replayCache to prevent "token replay detected" errors
cleanupReplayCache()
initReplayCache()
// Store original tokenVerifier to restore later
origTokenVerifier := ts.tOidc.tokenVerifier
// Create a mock tokenVerifier that clears the replay cache before verification
// This prevents replay detection when the same token is verified multiple times within a test
mockTokenVerifier := &MockTokenVerifier{
VerifyFunc: func(token string) error {
// Clear replay cache before token verification
cleanupReplayCache()
initReplayCache()
// For test tokens, perform basic validation without JWKS dependency
if isTestToken(token) {
// Parse the token to check basic validity and expiration
claims, err := extractClaims(token)
if err != nil {
return fmt.Errorf("token parsing failed: %v", err)
}
// Check token expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return fmt.Errorf("token has expired")
}
}
// Token is valid for test purposes - also cache the claims like the real verifier would
if ts.tOidc.tokenCache != nil {
ts.tOidc.tokenCache.Set(token, claims, time.Hour)
}
return nil
}
// For non-test tokens, call the original verifier
if origTokenVerifier != nil {
return origTokenVerifier.VerifyToken(token)
}
return fmt.Errorf("original token verifier is nil")
},
}
// Replace tokenVerifier with our mock
ts.tOidc.tokenVerifier = mockTokenVerifier
// Restore original tokenVerifier after test
defer func() {
ts.tOidc.tokenVerifier = origTokenVerifier
}()
req := httptest.NewRequest("GET", tc.requestPath, nil)
// Set common headers needed by the logic (determineScheme, determineHost)
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
req.Header.Set("X-Forwarded-Host", "testhost.com")
req.Host = "testhost.com" // Also set Host header
// Set request headers from test case
if tc.requestHeaders != nil {
for key, value := range tc.requestHeaders {
req.Header.Set(key, value)
}
}
rr := httptest.NewRecorder()
// Setup session if needed
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err)
}
if tc.setupSession != nil {
tc.setupSession(session)
// Save session to recorder to get cookies
saveRecorder := httptest.NewRecorder()
if err := session.Save(req, saveRecorder); err != nil {
t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err)
}
// Copy cookies from save recorder to the actual request
for _, cookie := range saveRecorder.Result().Cookies() {
req.AddCookie(cookie)
}
}
// Mocking setup for TokenExchanger
originalExchanger := ts.tOidc.tokenExchanger // Store original
mockExchanger, isMock := originalExchanger.(*MockTokenExchanger)
if !isMock {
// This case should ideally not happen if Setup correctly assigns the mock,
// but handle it defensively.
t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.")
mockExchanger = &MockTokenExchanger{
ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken,
RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken,
RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider,
}
ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock
}
// Override specific mock methods if needed for the test case
originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func
if tc.mockRefreshTokenFunc != nil {
// Assign the test case specific mock function
mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken)
}
// Call ServeHTTP
ts.tOidc.ServeHTTP(rr, req)
// Restore original exchanger and mock function state
ts.tOidc.tokenExchanger = originalExchanger
if tc.mockRefreshTokenFunc != nil && mockExchanger != nil {
// Restore the previous mock function if we overrode it
mockExchanger.RefreshTokenFunc = originalMockRefreshFunc
}
// Check response status
if rr.Code != tc.expectedStatus {
t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String())
}
// Check response body if expected
// Check response body if expected (handle JSON vs HTML)
if tc.expectedBody != "" {
// For JSON, compare directly
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
}
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
}
}
// Add more sophisticated HTML body checks if needed
}
// Perform post-request session assertions if defined
if tc.assertSessionAfterRequest != nil {
tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager)
}
})
}
}
func TestJWKToPEM(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
jwk *JWK
name string
errorContains string
expectError bool
}{
{
name: "Unsupported Key Type",
jwk: &JWK{
Kty: "unsupported",
Kid: "test-key-id",
},
expectError: true,
errorContains: "unsupported key type",
},
{
name: "EC Key",
jwk: &JWK{
Kty: "EC",
Kid: "test-ec-key-id",
Crv: "P-256",
X: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.X.Bytes()),
Y: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.Y.Bytes()),
},
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
pemBytes, err := jwkToPEM(tc.jwk)
if tc.expectError {
if err == nil {
t.Errorf("Expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pemBytes) == 0 {
t.Error("PEM bytes should not be empty")
}
}
})
}
}
func TestParseJWT(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
token string
errorContains string
expectError bool
}{
{
name: "Invalid Format",
token: "invalid.jwt.token",
expectError: true,
errorContains: "invalid JWT format",
},
{
name: "Valid Token",
token: ts.token,
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := parseJWT(tc.token)
if tc.expectError {
if err == nil {
t.Errorf("Expected error, got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
})
}
}
func TestJWTVerify_MissingClaims(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
jwt := &JWT{
Header: map[string]interface{}{
"alg": "RS256",
"kid": "test-key-id",
},
Claims: map[string]interface{}{
// Missing 'iss', 'aud', 'exp', 'iat', 'sub'
},
}
err := jwt.Verify("https://test-issuer.com", "test-client-id")
if err == nil {
t.Error("Expected error for missing claims, got nil")
}
}
func TestHandleCallback(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
redirectURL := "http://example.com/"
tests := []struct {
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
name string
queryParams string
expectedStatus int
}{
{
name: "Success",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Generate a unique token for this test case to avoid replay issues
// Use claims relevant to this test (disallowed email)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject-disallowed",
"email": "user@disallowed.com", // The disallowed email for this test
"nonce": "test-nonce", // Match the nonce set in sessionSetupFunc
"jti": generateRandomString(16), // Unique JTI
})
if err != nil {
return nil, fmt.Errorf("failed to create disallowed token for test: %w", err)
}
return &TokenResponse{
IDToken: disallowedToken,
RefreshToken: "test-refresh-token-disallowed",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusForbidden,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tests {
// Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Clear the global replay cache before each test run
cleanupReplayCache()
initReplayCache()
// Explicitly clear the shared blacklist at the start of each sub-test
// to ensure no state leaks, even though we expect the local one to be used.
// Note: This line might be redundant now that the verifier is local, but keep for safety.
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
logger := NewLogger("info")
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
// Create a new instance for each test to avoid state carryover
instanceExtractClaimsFunc := tc.extractClaimsFunc
if instanceExtractClaimsFunc == nil {
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
}
tOidc := &TraefikOidc{
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
userIdentifierClaim: "email", // Required for claim extraction
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
tokenVerifier: nil, // Will be set to self below
jwtVerifier: nil, // Temporarily nil, will be set below
sessionManager: sessionManager,
tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
// Wrap the test case function to match the required signature
if tc.exchangeCodeForToken != nil {
// Only call if the test case provided a function
return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier)
}
// Provide a default behavior or error if no mock was provided for this test case
return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case")
},
// Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback
},
tokenCache: NewTokenCache(), // Initialize token cache
limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter
tokenBlacklist: NewCache(), // Initialize token blacklist cache
// Add potentially missing fields based on New() comparison
clientID: ts.tOidc.clientID,
audience: ts.tOidc.clientID,
issuerURL: ts.tOidc.issuerURL,
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
httpClient: ts.tOidc.httpClient,
initComplete: make(chan struct{}), // Initialize the channel
// Setting other fields like paths, enablePKCE etc. if needed
}
tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW
tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW
close(tOidc.initComplete) // Mark this test instance as initialized
// Create request and response recorder
req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil)
rr := httptest.NewRecorder()
// Create session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder for the actual test
rr = httptest.NewRecorder()
// Call handleCallback
tOidc.handleCallback(rr, req, redirectURL)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
})
}
}
func TestIsAllowedDomain(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
allowedDomains map[string]struct{}
allowedUsers map[string]struct{}
name string
email string
expectedLogOutput string
allowed bool
}{
{
name: "Allowed domain",
email: "user@example.com",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: true,
},
{
name: "Disallowed domain",
email: "user@notallowed.com",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: false,
},
{
name: "Invalid email",
email: "invalid-email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: false,
},
{
name: "Specific user is allowed regardless of domain",
email: "specific.user@otherdomain.com",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
allowed: true,
},
{
name: "Case-insensitive email matching for specific user",
email: "Specific.User@otherdomain.com", // Mixed case
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, // Lowercase
allowed: true,
},
{
name: "Only allowed users configured (no domains)",
email: "specific.user@otherdomain.com",
allowedDomains: map[string]struct{}{}, // Empty allowed domains
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
allowed: true,
},
{
name: "User not in allowed list when only specific users configured",
email: "other.user@otherdomain.com",
allowedDomains: map[string]struct{}{}, // Empty allowed domains
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
allowed: false,
},
{
name: "No restrictions (both empty)",
email: "anyone@anydomain.com",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{},
allowed: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure TraefikOidc instance for this test case
tOidc := ts.tOidc
tOidc.allowedUserDomains = tc.allowedDomains
tOidc.allowedUsers = tc.allowedUsers
allowed := tOidc.isAllowedDomain(tc.email)
if allowed != tc.allowed {
t.Errorf("Expected allowed=%v, got %v", tc.allowed, allowed)
}
})
}
}
func TestIsAllowedUser(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
allowedDomains map[string]struct{}
allowedUsers map[string]struct{}
userIdentifierClaim string
name string
userIdentifier string
allowed bool
}{
// Email-based identification (default behavior)
{
name: "Email identifier - allowed domain",
userIdentifier: "user@example.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: true,
},
{
name: "Email identifier - disallowed domain",
userIdentifier: "user@notallowed.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: false,
},
{
name: "Email identifier - specific user allowed",
userIdentifier: "specific.user@otherdomain.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
allowed: true,
},
// Non-email identifier (sub claim - for Azure AD users without email)
{
name: "Sub identifier - allowed in allowedUsers",
userIdentifier: "abc12345-6789-0abc-def0-123456789abc",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
allowed: true,
},
{
name: "Sub identifier - not in allowedUsers",
userIdentifier: "xyz-not-allowed-user",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
allowed: false,
},
{
name: "Sub identifier - allowedDomains ignored for non-email",
userIdentifier: "user-id-12345",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored
allowedUsers: map[string]struct{}{"user-id-12345": {}},
allowed: true,
},
{
name: "Sub identifier - no restrictions allows all",
userIdentifier: "any-user-id",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{},
allowed: true,
},
{
name: "Sub identifier - case insensitive matching",
userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase
allowed: true,
},
// OID claim (Azure AD object ID)
{
name: "OID identifier - allowed user",
userIdentifier: "oid-12345-67890",
userIdentifierClaim: "oid",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"oid-12345-67890": {}},
allowed: true,
},
// UPN claim (Azure AD User Principal Name)
{
name: "UPN identifier - allowed user (looks like email but use sub logic)",
userIdentifier: "user@tenant.onmicrosoft.com",
userIdentifierClaim: "upn",
allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored
allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}},
allowed: true,
},
// Edge cases
{
name: "Empty identifier - not allowed",
userIdentifier: "",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"some-user": {}},
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure TraefikOidc instance for this test case
tOidc := ts.tOidc
tOidc.allowedUserDomains = tc.allowedDomains
tOidc.allowedUsers = tc.allowedUsers
tOidc.userIdentifierClaim = tc.userIdentifierClaim
allowed := tOidc.isAllowedUser(tc.userIdentifier)
if allowed != tc.allowed {
t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q",
tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim)
}
})
}
}
func TestUserIdentifierClaimExtraction(t *testing.T) {
// Test that the correct claim is extracted based on userIdentifierClaim config
tests := []struct {
name string
userIdentifierClaim string
claims map[string]interface{}
expectedIdentifier string
shouldFallbackToSub bool
}{
{
name: "Email claim extraction (default)",
userIdentifierClaim: "email",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
},
expectedIdentifier: "user@example.com",
shouldFallbackToSub: false,
},
{
name: "Sub claim extraction",
userIdentifierClaim: "sub",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
},
expectedIdentifier: "user-sub-id",
shouldFallbackToSub: false,
},
{
name: "OID claim extraction (Azure AD)",
userIdentifierClaim: "oid",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
"oid": "azure-object-id",
},
expectedIdentifier: "azure-object-id",
shouldFallbackToSub: false,
},
{
name: "UPN claim extraction (Azure AD)",
userIdentifierClaim: "upn",
claims: map[string]interface{}{
"sub": "user-sub-id",
"upn": "user@tenant.onmicrosoft.com",
},
expectedIdentifier: "user@tenant.onmicrosoft.com",
shouldFallbackToSub: false,
},
{
name: "Fallback to sub when configured claim is missing",
userIdentifierClaim: "email",
claims: map[string]interface{}{
"sub": "fallback-sub-id",
// email is missing
},
expectedIdentifier: "fallback-sub-id",
shouldFallbackToSub: true,
},
{
name: "preferred_username claim extraction",
userIdentifierClaim: "preferred_username",
claims: map[string]interface{}{
"sub": "user-sub-id",
"preferred_username": "jdoe",
},
expectedIdentifier: "jdoe",
shouldFallbackToSub: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Extract user identifier using the same logic as auth_flow.go
userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string)
usedFallback := false
if userIdentifier == "" && tc.userIdentifierClaim != "sub" {
userIdentifier, _ = tc.claims["sub"].(string)
usedFallback = true
}
if userIdentifier != tc.expectedIdentifier {
t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier)
}
if usedFallback != tc.shouldFallbackToSub {
t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback)
}
})
}
}
func TestOIDCHandler(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
ts.token = "valid.jwt.token"
tests := []struct {
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
name string
queryParams string
expectedStatus int
blacklist bool
rateLimit bool
cacheToken bool
}{
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with invalid nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims without nonce
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with mismatched nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
}
for _, tc := range tests {
// Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
// Use Set with a duration. Value 'true' is arbitrary.
ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour)
}
if tc.rateLimit {
// Exceed rate limit
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
}
if tc.cacheToken {
// Cache the token with dummy claims
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
"empty": "claim",
}, 60)
}
})
}
}
// TestHandleLogout tests the logout functionality
func TestHandleLogout(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Create mock revocation endpoint server
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify the required parameters are present
if r.Form.Get("token") == "" {
t.Error("Missing token parameter")
}
if r.Form.Get("token_type_hint") == "" {
t.Error("Missing token_type_hint parameter")
}
w.WriteHeader(http.StatusOK)
}))
defer mockRevocationServer.Close()
tests := []struct {
setupSession func(*SessionData)
name string
endSessionURL string
expectedURL string
host string
expectedStatus int
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
expectedURL: "https://provider/end-session?id_token_hint=" + url.QueryEscape(ValidIDToken) + "&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
host: "test-host",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
tOidc := &TraefikOidc{
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
logger: logger,
tokenBlacklist: NewCache(), // Use generic cache for blacklist
httpClient: &http.Client{},
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
sessionManager: sessionManager,
}
// Create request with proper headers
req := httptest.NewRequest("GET", "/logout", nil)
req.Header.Set("Host", tc.host)
// Create a response recorder
rr := httptest.NewRecorder()
// Get a session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Handle logout
tOidc.handleLogout(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
}
}
// Verify session is cleared
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
if updatedSession.GetAuthenticated() {
t.Error("Session still marked as authenticated")
}
// Check token blacklist
if token := session.GetAccessToken(); token != "" {
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Access token was not blacklisted in cache")
}
}
if token := session.GetRefreshToken(); token != "" {
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Refresh token was not blacklisted in cache")
}
}
})
}
}
// TestRevokeTokenWithProvider tests the token revocation with provider
func TestRevokeTokenWithProvider(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
token string
tokenType string
statusCode int
expectError bool
}{
{
name: "Successful token revocation",
token: "valid-token",
tokenType: "refresh_token",
statusCode: http.StatusOK,
expectError: false,
},
{
name: "Failed token revocation",
token: "invalid-token",
tokenType: "refresh_token",
statusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request method and content type
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
}
// Verify form values
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
if got := r.Form.Get("token"); got != tc.token {
t.Errorf("Expected token %s, got %s", tc.token, got)
}
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
}
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
}
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
}
w.WriteHeader(tc.statusCode)
}))
defer server.Close()
// Set revocation URL to test server
ts.tOidc.revocationURL = server.URL
// Test token revocation
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
if tc.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tc.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
}
// TestRevokeToken tests the token revocation functionality
func TestRevokeToken(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
token := "test.token.with.claims"
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// Test token revocation
t.Run("Token revocation", func(t *testing.T) {
// Create a new instance for this specific test
tOidc := &TraefikOidc{
tokenBlacklist: NewCache(), // Use generic cache for blacklist
tokenCache: NewTokenCache(),
logger: NewLogger("info"), // Initialize the logger
}
// Cache the token
tOidc.tokenCache.Set(token, claims, time.Hour)
// Revoke the token
tOidc.RevokeToken(token)
// Verify token was removed from cache
if _, exists := tOidc.tokenCache.Get(token); exists {
t.Error("Token was not removed from cache")
}
// Verify token was added to blacklist cache
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Token was not added to blacklist")
}
})
}
// Add this new test function
func TestBuildLogoutURL(t *testing.T) {
tests := []struct {
name string
endSessionURL string
idToken string
postLogoutRedirect string
expectedURL string
expectError bool
}{
{
name: "Valid URL",
endSessionURL: "https://provider/end-session",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
{
name: "Invalid URL",
endSessionURL: "://invalid-url",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectError: true,
},
{
name: "URL with existing query parameters",
endSessionURL: "https://provider/end-session?existing=param",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if url != tc.expectedURL {
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
}
}
})
}
}
// Add this new test function
func TestHandleExpiredToken(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
setupSession func(*SessionData)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetUserIdentifier("test@example.com")
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.mainSession.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
tOidc := &TraefikOidc{
sessionManager: sessionManager,
logger: logger,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
close(tOidc.initComplete)
// Create request
req := httptest.NewRequest("GET", tc.expectedPath, nil)
rr := httptest.NewRecorder()
// Get session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session data
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
// Get the updated session to verify changes
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
}
// Verify main session values
if updatedSession.GetCSRF() == "" {
t.Error("CSRF token not set")
}
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if updatedSession.GetNonce() == "" {
t.Error("Nonce not set")
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
// Verify redirect status
if rr.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
}
})
}
}
// Add this new test function
func TestExtractGroupsAndRoles(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
claims map[string]interface{}
expectGroups []string
expectRoles []string
expectError bool
}{
{
name: "Valid groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{"group1", "group2"},
"roles": []interface{}{"role1", "role2"},
},
expectGroups: []string{"group1", "group2"},
expectRoles: []string{"role1", "role2"},
expectError: false,
},
{
name: "Empty groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{},
"roles": []interface{}{},
},
expectGroups: []string{},
expectRoles: []string{},
expectError: false,
},
{
name: "Invalid groups format",
claims: map[string]interface{}{
"groups": "not-an-array",
"roles": []interface{}{"role1"},
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test token with the claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Compare groups
if !stringSliceEqual(groups, tc.expectGroups) {
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
}
// Compare roles
if !stringSliceEqual(roles, tc.expectRoles) {
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
}
}
})
}
}
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
// can be created and initialized properly for different routes
func TestMultipleMiddlewareInstances(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
defer mockServer.Close()
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
// Type assert to access internal fields
if m, ok := middleware.(*TraefikOidc); ok {
middlewares = append(middlewares, m)
} else {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
}
// Clean up all middleware instances to prevent goroutine leaks
defer func() {
for i, m := range middlewares {
if err := m.Close(); err != nil {
t.Errorf("Failed to close middleware instance %d: %v", i, err)
}
}
}()
// Wait for all instances to initialize
for i, m := range middlewares {
select {
case <-m.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
}
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
}
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
}
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
}
}
// Test that each instance can handle requests independently
for i, m := range middlewares {
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
rr := httptest.NewRecorder()
m.ServeHTTP(rr, req)
// Should redirect to auth URL since not authenticated
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
// TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances
// with different provider URLs (e.g., different Keycloak realms) get separate
// metadata refresh tasks. This addresses the issue reported in PR #88.
func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Create two mock provider metadata servers simulating different Keycloak realms
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm1",
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
defer realm1Server.Close()
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm2",
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
defer realm2Server.Close()
// Config for realm1
config1 := &Config{
ProviderURL: realm1Server.URL,
ClientID: "realm1-client",
ClientSecret: "realm1-secret",
CallbackURL: "/realm1/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm1_",
}
// Config for realm2
config2 := &Config{
ProviderURL: realm2Server.URL,
ClientID: "realm2-client",
ClientSecret: "realm2-secret",
CallbackURL: "/realm2/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm2_",
}
// Create middleware instances for both realms
middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config1, "realm1-middleware")
if err != nil {
t.Fatalf("Failed to create middleware for realm1: %v", err)
}
middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config2, "realm2-middleware")
if err != nil {
t.Fatalf("Failed to create middleware for realm2: %v", err)
}
m1, ok1 := middleware1.(*TraefikOidc)
m2, ok2 := middleware2.(*TraefikOidc)
if !ok1 || !ok2 {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
// Clean up middleware instances
defer func() {
if err := m1.Close(); err != nil {
t.Errorf("Failed to close realm1 middleware: %v", err)
}
if err := m2.Close(); err != nil {
t.Errorf("Failed to close realm2 middleware: %v", err)
}
}()
// Wait for both instances to initialize
select {
case <-m1.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Realm1 middleware failed to initialize")
}
select {
case <-m2.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Realm2 middleware failed to initialize")
}
// Verify each instance has the correct issuer URL from their respective realms
if !strings.Contains(m1.issuerURL, "realm1") {
t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL)
}
if !strings.Contains(m2.issuerURL, "realm2") {
t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL)
}
// Verify provider URLs are different
if m1.providerURL == m2.providerURL {
t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL)
}
// Test that each middleware can handle requests independently
req1 := httptest.NewRequest("GET", "/realm1/protected", nil)
rr1 := httptest.NewRecorder()
m1.ServeHTTP(rr1, req1)
req2 := httptest.NewRequest("GET", "/realm2/protected", nil)
rr2 := httptest.NewRecorder()
m2.ServeHTTP(rr2, req2)
// Both should redirect to their respective auth URLs
if rr1.Code != http.StatusFound {
t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code)
}
if rr2.Code != http.StatusFound {
t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code)
}
location1 := rr1.Header().Get("Location")
location2 := rr2.Header().Get("Location")
if !strings.Contains(location1, "realm1") {
t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1)
}
if !strings.Contains(location2, "realm2") {
t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2)
}
}
// TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically
// recovers when the OIDC provider becomes available after initial failure.
func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Track whether the provider is "available"
providerAvailable := false
var mu sync.Mutex
// Create mock provider that initially fails, then becomes available
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
available := providerAvailable
mu.Unlock()
if !available {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
if r.URL.Path == "/.well-known/openid-configuration" {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
EndSessionURL: "https://test-issuer.com/logout",
}
json.NewEncoder(w).Encode(metadata)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer mockServer.Close()
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create middleware while provider is unavailable
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test-recovery")
if err != nil {
t.Fatalf("Failed to create middleware: %v", err)
}
m, ok := middleware.(*TraefikOidc)
if !ok {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
defer m.Close()
// Wait for initial initialization to complete (it should fail)
select {
case <-m.initComplete:
case <-time.After(15 * time.Second):
t.Fatal("Initialization did not complete in time")
}
// Verify initial state - should be in failed state (no issuerURL)
m.metadataMu.RLock()
initialIssuer := m.issuerURL
m.metadataMu.RUnlock()
if initialIssuer != "" {
t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer)
}
// First request should get 503
req1 := httptest.NewRequest("GET", "/protected", nil)
rr1 := httptest.NewRecorder()
m.ServeHTTP(rr1, req1)
if rr1.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code)
}
// Now make the provider available
mu.Lock()
providerAvailable = true
mu.Unlock()
// Reset the retry timer to allow immediate retry
m.metadataRetryMutex.Lock()
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
m.metadataRetryMutex.Unlock()
// Second request should trigger recovery attempt
req2 := httptest.NewRequest("GET", "/protected", nil)
rr2 := httptest.NewRecorder()
m.ServeHTTP(rr2, req2)
// Give the async recovery a moment to complete
time.Sleep(100 * time.Millisecond)
// Check if recovery happened
m.metadataMu.RLock()
recoveredIssuer := m.issuerURL
m.metadataMu.RUnlock()
if recoveredIssuer == "" {
t.Error("Expected issuerURL to be recovered after provider became available")
}
// Third request should succeed (redirect to auth, not 503)
req3 := httptest.NewRequest("GET", "/protected", nil)
rr3 := httptest.NewRecorder()
m.ServeHTTP(rr3, req3)
if rr3.Code == http.StatusServiceUnavailable {
t.Errorf("Expected redirect after recovery, still got 503")
}
t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d",
initialIssuer, recoveredIssuer, rr3.Code)
}
func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Create consistent timestamps for all test cases
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
allowedRolesAndGroups map[string]struct{}
claims map[string]interface{}
setupSession func(*SessionData)
expectedHeaders map[string]string
name string
expectedStatus int
}{
{
name: "User with allowed role",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "admin,user",
"X-User-Groups": "group1",
},
},
{
name: "User with allowed group",
allowedRolesAndGroups: map[string]struct{}{
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "allowed-group",
},
},
{
name: "User without allowed roles or groups",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
{
name: "No role/group restrictions",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "regular-group",
},
},
{
name: "Claims without roles and groups",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
// Create test handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Configure OIDC middleware
tOidc := ts.tOidc
tOidc.next = nextHandler
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
// Create request
req := httptest.NewRequest("GET", "/protected", nil)
rr := httptest.NewRecorder()
// Set up session
session, err := tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
tc.setupSession(session)
session.SetAccessToken(token)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Serve request
tOidc.ServeHTTP(rr, req)
// Check status code
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check headers if status is OK
if tc.expectedStatus == http.StatusOK {
for header, expectedValue := range tc.expectedHeaders {
if value := req.Header.Get(header); value != expectedValue {
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
}
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
setupServer func() *httptest.Server
name string
errorContains string
expectError bool
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := range 3 {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
codeChallenge string
expectedPrefix string
enablePKCE bool
checkPKCE bool
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://logto.example.com/oidc/auth?",
checkPKCE: false,
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
checkPKCE: false,
},
{
name: "With PKCE Enabled",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: true,
},
{
name: "With PKCE Enabled but No Challenge",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "With PKCE Disabled but Challenge Provided",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
tOidc.enablePKCE = tc.enablePKCE
// Call buildAuthURL with code challenge
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify PKCE parameters
if tc.checkPKCE {
if got := query.Get("code_challenge"); got != tc.codeChallenge {
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
}
} else {
if got := query.Get("code_challenge"); got != "" {
t.Errorf("Expected no code_challenge, but got %q", got)
}
if got := query.Get("code_challenge_method"); got != "" {
t.Errorf("Expected no code_challenge_method, but got %q", got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
func TestExchangeCodeForToken(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
setupMock func(t *testing.T) *httptest.Server
name string
codeVerifier string
enablePKCE bool
}{
{
name: "With PKCE Enabled and Code Verifier",
enablePKCE: true,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Disabled but Code Verifier Provided",
enablePKCE: false,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Enabled but No Code Verifier",
enablePKCE: true,
codeVerifier: "",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupMock(t)
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
tOidc.enablePKCE = tc.enablePKCE
// Test exchangeCodeForToken
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
responseRecorder := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
expectedPath := "/protected/resource?param1=value1&param2=value2"
if incomingPath != expectedPath {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
func TestVerifyTimeConstraint(t *testing.T) {
// Define tolerances used in jwt.go (ensure they match)
toleranceFuture := 2 * time.Minute
tolerancePast := 10 * time.Second
now := time.Now()
tests := []struct {
name string
claimTime time.Time
claimName string
futureCheck bool // true for exp, false for iat/nbf
expectError bool
}{
// Expiration (future=true, tolerance=2min)
{
name: "EXP: Valid (expires in 1 min)",
claimTime: now.Add(1 * time.Minute),
claimName: "exp",
futureCheck: true,
expectError: false,
},
{
name: "EXP: Expired (expired 3 min ago)",
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
claimName: "exp",
futureCheck: true,
expectError: true,
},
{
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
claimName: "exp",
futureCheck: true,
expectError: false, // Should be allowed due to future tolerance
},
// Issued At (future=false, tolerance=10s)
{
name: "IAT: Valid (issued 1 min ago)",
claimTime: now.Add(-1 * time.Minute),
claimName: "iat",
futureCheck: false,
expectError: false,
},
{
name: "IAT: Invalid (issued 15 sec in future)",
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
claimName: "iat",
futureCheck: false,
expectError: true, // "token used before issued"
},
{
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
claimName: "iat",
futureCheck: false,
expectError: false, // Should be allowed due to past tolerance
},
// Not Before (future=false, tolerance=10s)
{
name: "NBF: Valid (active 1 min ago)",
claimTime: now.Add(-1 * time.Minute),
claimName: "nbf",
futureCheck: false,
expectError: false,
},
{
name: "NBF: Invalid (active in 15 sec)",
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
claimName: "nbf",
futureCheck: false,
expectError: true, // "token not yet valid"
},
{
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
claimName: "nbf",
futureCheck: false,
expectError: false, // Should be allowed due to past tolerance
},
}
// Temporarily adjust global tolerances for test consistency, then restore
originalFutureTolerance := ClockSkewToleranceFuture
originalPastTolerance := ClockSkewTolerancePast
ClockSkewToleranceFuture = toleranceFuture
ClockSkewTolerancePast = tolerancePast
defer func() {
ClockSkewToleranceFuture = originalFutureTolerance
ClockSkewTolerancePast = originalPastTolerance
}()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Convert claim time to float64 unix timestamp
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
var err error
// Call the specific verification function which uses verifyTimeConstraint
if tc.claimName == "exp" {
err = verifyExpiration(unixTime)
} else if tc.claimName == "iat" {
err = verifyIssuedAt(unixTime)
} else if tc.claimName == "nbf" {
err = verifyNotBefore(unixTime)
} else {
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
}
if tc.expectError {
if err == nil {
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
} else {
t.Logf("Got expected error: %v", err) // Log the error for confirmation
}
} else {
if err != nil {
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
}
}
})
}
} // Add missing closing brace for TestVerifyTimeConstraint
// ===== JWT REPLAY DETECTION TESTS =====
// These tests ensure the replay detection fix works correctly and prevents regressions
// TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality
func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache before test
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
tests := []struct {
name string
errorContains string
skipReplayCheck bool
firstCall bool
expectError bool
}{
{
name: "First verification with skipReplayCheck=false should succeed",
skipReplayCheck: false,
firstCall: true,
expectError: false,
},
{
name: "Second verification with skipReplayCheck=false should fail (replay detected)",
skipReplayCheck: false,
firstCall: false,
expectError: true,
errorContains: "token replay detected",
},
{
name: "Verification with skipReplayCheck=true should always succeed",
skipReplayCheck: true,
firstCall: false, // Even on subsequent calls
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.firstCall {
// Clear replay cache for first call tests
cleanupReplayCache()
initReplayCache()
}
err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck)
if tc.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', but got nil", tc.errorContains)
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
}
} else {
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}
})
}
}
// TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking
func TestJWTVerifyBackwardCompatibility(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// First call with old signature (no skipReplayCheck parameter) should succeed
err = jwt.Verify("https://test-issuer.com", "test-client-id")
if err != nil {
t.Errorf("First verification should succeed, got: %v", err)
}
// Second call with old signature should fail due to replay detection
err = jwt.Verify("https://test-issuer.com", "test-client-id")
if err == nil {
t.Error("Second verification should fail due to replay detection")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives
func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Simulate the authentication flow that was causing false positives:
// 1. Initial authentication adds JTI to cache
// 2. Subsequent request validation should not trigger false positive
// Step 1: Initial authentication (this would add JTI to cache)
jwt1, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for initial auth: %v", err)
}
err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Step 2: Subsequent request validation (this should skip replay check to avoid false positive)
jwt2, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for subsequent request: %v", err)
}
err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check
if err != nil {
t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err)
}
// Step 3: Verify that actual replay attacks are still detected
jwt3, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
}
err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
if err == nil {
t.Error("Actual replay attack should be detected when skipReplayCheck=false")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestAuthenticationFlowReplayDetection tests the complete authentication flow
func TestAuthenticationFlowReplayDetection(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Test the complete flow:
// 1. Initial authentication (should add JTI to cache)
// 2. Multiple subsequent requests (should not trigger false positives)
// 3. Actual replay attack from different source (should be detected)
// Step 1: Initial authentication
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Verify JTI is in cache (use shardedReplayCache which is the actual cache used)
exists := shardedReplayCache.Exists(jti)
if !exists {
t.Error("JTI should be added to replay cache during initial authentication")
}
// Step 2: Subsequent requests (simulate normal request processing)
// These should use the token cache and skip replay detection
for i := range 3 {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Subsequent request %d should succeed: %v", i+1, err)
}
}
// Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
}
err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check
if err == nil {
t.Error("Actual replay attack should be detected")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestActualReplayAttackDetection ensures real replay attacks are still properly detected
func TestActualReplayAttackDetection(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// First verification should succeed
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
if err != nil {
t.Fatalf("First verification should succeed: %v", err)
}
// Simulate different types of replay attacks
replayTests := []struct {
name string
description string
}{
{
name: "Direct replay attack",
description: "Same token used again with replay checking enabled",
},
{
name: "Replay from different source",
description: "Token intercepted and replayed by attacker",
},
}
for _, rt := range replayTests {
t.Run(rt.name, func(t *testing.T) {
// Parse token again (simulating replay)
replayJWT, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay test: %v", err)
}
// Attempt replay with normal replay checking
err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false)
if err == nil {
t.Errorf("Replay attack should be detected for: %s", rt.description)
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err)
}
})
}
}
// TestConcurrentTokenValidation tests thread safety of replay detection
func TestConcurrentTokenValidation(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure rate limiter to allow more requests for concurrent testing
ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create multiple tokens with unique JTIs
var tokens []string
var jtis []string
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
for i := range 10 {
jti := generateRandomString(16)
jtis = append(jtis, jti)
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT %d: %v", i, err)
}
tokens = append(tokens, token)
}
// Test concurrent validation
const numGoroutines = 20
const numIterations = 5
results := make(chan error, numGoroutines*numIterations)
for g := range numGoroutines {
go func(goroutineID int) {
for i := range numIterations {
tokenIndex := (goroutineID + i) % len(tokens)
token := tokens[tokenIndex]
// First validation should succeed
err := ts.tOidc.VerifyToken(token)
results <- err
// Subsequent validation with same token should also succeed (uses cache)
err = ts.tOidc.VerifyToken(token)
results <- err
}
}(g)
}
// Collect results
var errors []error
for range numGoroutines * numIterations * 2 {
if err := <-results; err != nil {
errors = append(errors, err)
}
}
// All validations should succeed (no race conditions)
if len(errors) > 0 {
t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors)
}
// Verify all JTIs are in cache (use shardedReplayCache which is the actual cache used)
for i, jti := range jtis {
if !shardedReplayCache.Exists(jti) {
t.Errorf("JTI %d (%s) should be in replay cache", i, jti)
}
}
}
// TestJTIBlacklistBehavior tests the JTI blacklist cache management
func TestJTIBlacklistBehavior(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Properly reinitialize the global replay cache
cleanupReplayCache() // Clean up any existing cache and reset sync.Once
initReplayCache() // Initialize new cache through proper channel
// Create a test JWT with unique JTI
jti := generateRandomString(16)
t.Logf("TestJTIBlacklistBehavior - JTI: %s", jti)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Test JTI blacklist behavior
tests := []struct {
action func() error
name string
description string
expectError bool
}{
{
name: "Initial verification adds JTI to blacklist",
action: func() error {
return ts.tOidc.VerifyToken(token)
},
expectError: false,
description: "First verification should succeed and add JTI to blacklist",
},
{
name: "JTI exists in blacklist after verification",
action: func() error {
// Use shardedReplayCache which is the actual cache used
if !shardedReplayCache.Exists(jti) {
return fmt.Errorf("JTI not found in blacklist cache")
}
return nil
},
expectError: false,
description: "JTI should be present in blacklist cache",
},
{
name: "Subsequent verification uses cache (no replay check)",
action: func() error {
return ts.tOidc.VerifyToken(token)
},
expectError: false,
description: "Subsequent verification should succeed using token cache",
},
{
name: "Direct JWT verification detects replay",
action: func() error {
jwt, err := parseJWT(token)
if err != nil {
return err
}
return jwt.Verify("https://test-issuer.com", "test-client-id", false)
},
expectError: true,
description: "Direct JWT verification should detect replay",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.action()
if tc.expectError {
if err == nil {
t.Errorf("Expected error for %s, but got nil", tc.description)
}
} else {
if err != nil {
t.Errorf("Expected no error for %s, but got: %v", tc.description, err)
}
}
})
}
}
// TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios
func TestSessionBasedTokenRevalidation(t *testing.T) {
if testing.Short() {
t.Skip("Skipping session-based token revalidation test in short mode")
}
ts := NewTestSuite(t)
ts.Setup()
// Clear the global replay cache
cleanupReplayCache()
initReplayCache()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Simulate session-based token revalidation scenario
// This tests the specific case that was causing false positives
// Step 1: Initial authentication (callback processing)
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Step 2: Multiple session-based requests (normal request processing)
// These should not trigger replay detection false positives
for i := range 5 {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Session request %d should succeed: %v", i+1, err)
}
}
// Step 3: Verify token is in both caches appropriately
// Check token cache
if _, exists := ts.tOidc.tokenCache.Get(token); !exists {
t.Error("Token should be in token cache")
}
// Check replay cache
// Use shardedReplayCache which is the actual cache used
inReplayCache := shardedReplayCache.Exists(jti)
if !inReplayCache {
t.Error("JTI should be in replay cache")
}
// Step 4: Verify that clearing token cache still allows validation
ts.tOidc.tokenCache = NewTokenCache() // Clear token cache
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Token validation should succeed even after cache clear: %v", err)
}
}
// TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types
func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Properly reinitialize the global replay cache
cleanupReplayCache() // Clean up any existing cache and reset sync.Once
initReplayCache() // Initialize new cache through proper channel
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
tests := []struct {
claims map[string]interface{}
name string
tokenType string
expectError bool
}{
{
name: "ID Token with JTI",
tokenType: "id_token",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
"token_type": "id_token",
},
expectError: false,
},
{
name: "Access Token with JTI",
tokenType: "access_token",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"scope": "openid profile email",
"jti": generateRandomString(16),
"token_type": "access_token",
},
expectError: false,
},
{
name: "Token without JTI",
tokenType: "no_jti",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
// No JTI claim
},
expectError: false, // Should still work, just no replay protection
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with specific claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// First verification should succeed
err = ts.tOidc.VerifyToken(token)
if tc.expectError {
if err == nil {
t.Errorf("Expected error for token type %s, but got nil", tc.tokenType)
}
} else {
if err != nil {
t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err)
}
}
// Second verification should also succeed (uses cache)
if !tc.expectError {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err)
}
}
// Test direct JWT verification for replay detection
if !tc.expectError && tc.claims["jti"] != nil {
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// This should detect replay for tokens with JTI
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
if err == nil {
t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType)
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err)
}
}
})
}
}
// TestScopeMerging tests the scope append functionality
func TestScopeMerging(t *testing.T) {
// Helper function to compare string slices
equalSlices := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
tests := []struct {
name string
defaultScopes []string
userScopes []string
expectedScopes []string
}{
{
name: "Empty user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{},
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Nil user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: nil,
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "New scopes are appended",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Deduplication - user scope already in defaults",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "custom_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
},
{
name: "Duplicate user scopes are removed",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"custom_scope", "custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Multiple overlapping scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"profile", "custom_scope", "email", "another_scope", "profile"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Only custom scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"read:users", "write:users", "admin"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
},
{
name: "Empty defaults",
defaultScopes: []string{},
userScopes: []string{"custom1", "custom2"},
expectedScopes: []string{"custom1", "custom2"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test the mergeScopes function directly
result := mergeScopes(tc.defaultScopes, tc.userScopes)
if !equalSlices(result, tc.expectedScopes) {
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
}
})
}
}
// TestScopeMergingEdgeCases tests additional edge cases for scope deduplication
func TestScopeMergingEdgeCases(t *testing.T) {
// Helper function to compare string slices
equalSlices := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
tests := []struct {
name string
description string
defaultScopes []string
userScopes []string
expectedScopes []string
}{
{
name: "Case sensitivity preserved",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"OpenID", "PROFILE", "custom"},
expectedScopes: []string{"openid", "profile", "email", "OpenID", "PROFILE", "custom"},
description: "OAuth scopes are case-sensitive, so different cases should be preserved",
},
{
name: "Empty strings in user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"", "custom", "", "another"},
expectedScopes: []string{"openid", "profile", "email", "", "custom", "another"},
description: "Empty strings should be preserved (though invalid in OAuth)",
},
{
name: "Whitespace scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{" ", "custom", " ", "another"},
expectedScopes: []string{"openid", "profile", "email", " ", "custom", " ", "another"},
description: "Whitespace-only scopes should be preserved as distinct",
},
{
name: "Large number of scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: generateLargeUserScopes(),
expectedScopes: func() []string {
// Manually calculate expected result with proper deduplication
defaults := []string{"openid", "profile", "email"}
userScopes := generateLargeUserScopes()
return mergeScopes(defaults, userScopes)
}(),
description: "Performance test with larger scope lists",
},
{
name: "Complex OAuth scopes with special characters",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
description: "Real-world OAuth scopes with colons, slashes, and special characters",
},
{
name: "Duplicate defaults in user scopes multiple times",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "profile", "openid", "custom", "email", "profile", "custom"},
expectedScopes: []string{"openid", "profile", "email", "custom"},
description: "Multiple duplicates of default scopes should be completely deduplicated",
},
{
name: "All user scopes are duplicates of defaults",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"email", "openid", "profile", "openid"},
expectedScopes: []string{"openid", "profile", "email"},
description: "When all user scopes duplicate defaults, result should be just defaults",
},
{
name: "Single scope scenarios",
defaultScopes: []string{"openid"},
userScopes: []string{"custom"},
expectedScopes: []string{"openid", "custom"},
description: "Minimal case with single scopes",
},
{
name: "Identical scopes in same order",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "profile", "email"},
expectedScopes: []string{"openid", "profile", "email"},
description: "When user scopes exactly match defaults, no duplication",
},
{
name: "Identical scopes in different order",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"email", "profile", "openid"},
expectedScopes: []string{"openid", "profile", "email"},
description: "Order of defaults is preserved when user scopes are reordered duplicates",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test the mergeScopes function directly
result := mergeScopes(tc.defaultScopes, tc.userScopes)
if !equalSlices(result, tc.expectedScopes) {
t.Errorf("Expected %v, got %v\nDescription: %s", tc.expectedScopes, result, tc.description)
}
})
}
}
// generateLargeUserScopes creates a large list of user scopes for performance testing
func generateLargeUserScopes() []string {
scopes := make([]string, 100)
for i := range 100 {
scopes[i] = fmt.Sprintf("scope_%d", i)
}
// Add some duplicates to test deduplication performance
scopes = append(scopes, "scope_1", "scope_5", "scope_10", "openid") // Include a default duplicate
return scopes
}
// TestScopeMergingPerformance tests performance with large scope lists
func TestScopeMergingPerformance(t *testing.T) {
// Create large scope lists
defaultScopes := []string{"openid", "profile", "email"}
// Create 1000 user scopes with some duplicates
userScopes := make([]string, 1000)
for i := range 1000 {
if i%10 == 0 {
// Add some duplicates of defaults
userScopes[i] = defaultScopes[i%len(defaultScopes)]
} else if i%7 == 0 {
// Add some internal duplicates
userScopes[i] = fmt.Sprintf("scope_%d", i%50)
} else {
userScopes[i] = fmt.Sprintf("scope_%d", i)
}
}
// Measure performance
start := time.Now()
result := mergeScopes(defaultScopes, userScopes)
duration := time.Since(start)
// Verify result correctness
if len(result) < len(defaultScopes) {
t.Errorf("Result should contain at least the default scopes")
}
// Verify no duplicates exist
seen := make(map[string]bool)
for _, scope := range result {
if seen[scope] {
t.Errorf("Duplicate scope found in result: %s", scope)
}
seen[scope] = true
}
// Performance assertion (should be very fast)
if duration > time.Millisecond*10 {
t.Logf("Performance note: mergeScopes took %v for 1000+ scopes (still acceptable)", duration)
}
t.Logf("Performance: processed %d user scopes in %v, result has %d unique scopes",
len(userScopes), duration, len(result))
}
// TestScopeMergingMemoryEfficiency tests memory efficiency of the mergeScopes function
func TestScopeMergingMemoryEfficiency(t *testing.T) {
defaultScopes := []string{"openid", "profile", "email"}
userScopes := []string{"custom1", "custom2"}
// Test that the function doesn't modify input slices
originalDefaults := make([]string, len(defaultScopes))
copy(originalDefaults, defaultScopes)
originalUser := make([]string, len(userScopes))
copy(originalUser, userScopes)
result := mergeScopes(defaultScopes, userScopes)
// Verify input slices are unchanged
for i, scope := range defaultScopes {
if scope != originalDefaults[i] {
t.Errorf("Default scopes were modified: expected %s, got %s", originalDefaults[i], scope)
}
}
for i, scope := range userScopes {
if scope != originalUser[i] {
t.Errorf("User scopes were modified: expected %s, got %s", originalUser[i], scope)
}
}
// Verify result is independent
result[0] = "modified"
if defaultScopes[0] == "modified" {
t.Error("Modifying result affected input defaults")
}
expectedLength := len(defaultScopes) + len(userScopes)
if len(result) != expectedLength {
t.Errorf("Expected result length %d, got %d", expectedLength, len(result))
}
}
// TestNewWithScopeAppending tests that the New function properly merges scopes
func TestNewWithScopeAppending(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
}
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
defer mockServer.Close()
tests := []struct {
name string
configScopes []string
expectedScopes []string
}{
{
name: "Default scopes only",
configScopes: []string{},
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Custom scopes appended",
configScopes: []string{"custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Overlapping scopes deduplicated",
configScopes: []string{"openid", "custom_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
},
{
name: "OAuth scopes",
configScopes: []string{"read:users", "write:users", "admin"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create config with test scopes
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
Scopes: tc.configScopes,
}
// Create middleware instance
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware: %v", err)
}
// Wait for initialization
if m, ok := middleware.(*TraefikOidc); ok {
// Ensure middleware is properly closed to prevent goroutine leaks
defer func() {
if err := m.Close(); err != nil {
t.Errorf("Failed to close middleware: %v", err)
}
}()
select {
case <-m.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Middleware failed to initialize")
}
// Check that scopes were properly merged
if !equalSlices(m.scopes, tc.expectedScopes) {
t.Errorf("Expected scopes %v, got %v", tc.expectedScopes, m.scopes)
}
} else {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
})
}
}
// TestBuildAuthURLWithMergedScopes tests that the auth URL includes the properly merged scopes
func TestBuildAuthURLWithMergedScopes(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
name string
expectedScopes string
scopes []string
}{
{
name: "Default scopes only",
scopes: []string{"openid", "profile", "email"},
expectedScopes: "openid profile email offline_access",
},
{
name: "Custom scopes appended",
scopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
expectedScopes: "openid profile email custom_scope another_scope offline_access",
},
{
name: "OAuth scopes",
scopes: []string{"openid", "profile", "email", "read:users", "write:users"},
expectedScopes: "openid profile email read:users write:users offline_access",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance with specific scopes
tOidc := ts.tOidc
tOidc.scopes = tc.scopes // These scopes are already deduplicated by New()
tOidc.authURL = "https://auth.example.com/oauth/authorize"
tOidc.issuerURL = "https://auth.example.com"
// Reset overrideScopes for each test case, as it's part of tOidc state
// Default to false, specific tests will set it.
tOidc.overrideScopes = false
// Build auth URL
result := tOidc.buildAuthURL("https://app.example.com/callback", "test-state", "test-nonce", "")
// Parse the resulting URL to verify scopes
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
actualScopes := query.Get("scope")
if actualScopes != tc.expectedScopes {
t.Errorf("Expected scopes %q, got %q", tc.expectedScopes, actualScopes)
}
})
}
}
// TestBuildAuthURL_OverrideScopes_And_OfflineAccess tests the offline_access logic in buildAuthURL
// considering the overrideScopes flag.
func TestBuildAuthURL_OverrideScopes_And_OfflineAccess(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup() // Sets up ts.tOidc
tests := []struct {
expectedParams map[string]string
name string
expectedScope string
initialScopes []string
overrideScopes bool
isGoogle bool
isAzure bool
}{
{
name: "Override false, no user scopes, non-Google/Azure",
initialScopes: []string{"openid", "profile", "email"}, // Defaults from New() when config.Scopes is empty
overrideScopes: false,
expectedScope: "openid profile email offline_access",
},
{
name: "Override false, user scopes without offline_access, non-Google/Azure",
initialScopes: []string{"openid", "profile", "email", "custom1"}, // Merged and deduplicated by New()
overrideScopes: false,
expectedScope: "openid profile email custom1 offline_access",
},
{
name: "Override false, user scopes with offline_access, non-Google/Azure",
initialScopes: []string{"openid", "profile", "email", "offline_access", "custom1"},
overrideScopes: false,
expectedScope: "openid profile email offline_access custom1", // Order might vary based on merge, but offline_access present
},
{
name: "Override true, user scopes without offline_access, non-Google/Azure",
initialScopes: []string{"custom1", "custom2"}, // Directly from config.Scopes, deduplicated
overrideScopes: true,
expectedScope: "custom1 custom2", // offline_access NOT added
},
{
name: "Override true, user scopes with offline_access, non-Google/Azure",
initialScopes: []string{"custom1", "offline_access", "custom2"},
overrideScopes: true,
expectedScope: "custom1 offline_access custom2", // User explicitly included it
},
{
name: "Override true, no user scopes (edge case), non-Google/Azure",
initialScopes: []string{}, // config.Scopes was empty
overrideScopes: true,
// In this edge case, buildAuthURL's logic `(t.overrideScopes && len(t.scopes) == 0)`
// will lead to offline_access being added, as it behaves like defaults.
expectedScope: "offline_access",
},
// Google Provider Tests (access_type=offline, prompt=consent)
{
name: "Google, Override false, no user scopes",
initialScopes: []string{"openid", "profile", "email"},
overrideScopes: false,
isGoogle: true,
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
expectedScope: "openid profile email", // No offline_access scope for Google
},
{
name: "Google, Override true, user scopes",
initialScopes: []string{"custom1", "custom2"},
overrideScopes: true,
isGoogle: true,
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
expectedScope: "custom1 custom2", // No offline_access scope for Google
},
// Azure Provider Tests (response_mode=query, offline_access scope added if not present by user)
{
name: "Azure, Override false, no user scopes",
initialScopes: []string{"openid", "profile", "email"},
overrideScopes: false,
isAzure: true,
expectedParams: map[string]string{"response_mode": "query"},
expectedScope: "openid profile email offline_access",
},
{
name: "Azure, Override true, user scopes without offline_access",
initialScopes: []string{"custom1", "custom2"},
overrideScopes: true,
isAzure: true,
expectedParams: map[string]string{"response_mode": "query"},
expectedScope: "custom1 custom2", // offline_access NOT added by default when override is true
},
{
name: "Azure, Override true, user scopes with offline_access",
initialScopes: []string{"custom1", "offline_access"},
overrideScopes: true,
isAzure: true,
expectedParams: map[string]string{"response_mode": "query"},
expectedScope: "custom1 offline_access",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tOidc := ts.tOidc
tOidc.scopes = tc.initialScopes // Set the scopes as if they came from New()
tOidc.overrideScopes = tc.overrideScopes
// Adjust issuerURL for provider-specific tests
originalIssuerURL := tOidc.issuerURL
if tc.isGoogle {
tOidc.issuerURL = "https://accounts.google.com"
} else if tc.isAzure {
tOidc.issuerURL = "https://login.microsoftonline.com/common"
} else {
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
}
authURLString := tOidc.buildAuthURL("http://localhost/callback", "state123", "nonce123", "challenge123")
parsedAuthURL, err := url.Parse(authURLString)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedAuthURL.Query()
actualScope := query.Get("scope")
if actualScope != tc.expectedScope {
t.Errorf("Expected scope string %q, got %q", tc.expectedScope, actualScope)
}
if tc.expectedParams != nil {
for k, v := range tc.expectedParams {
if query.Get(k) != v {
t.Errorf("Expected param %s=%s, got %s", k, v, query.Get(k))
}
}
}
// Restore original issuerURL for next test
tOidc.issuerURL = originalIssuerURL
})
}
}
// TestBuildAuthURL_SpecificUserCase tests the buildAuthURL function with the specific user-reported scenario.
func TestBuildAuthURL_SpecificUserCase(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup() // Basic setup for tOidc
// Configure the TraefikOidc instance for the specific scenario
tOidc := ts.tOidc
tOidc.scopes = []string{"email", "test3"} // This is what t.scopes should be after New()
tOidc.overrideScopes = true
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
tOidc.authURL = "https://generic-provider.com/auth" // Dummy auth URL
tOidc.clientID = "test-client-id"
// Expected scope string in the URL
expectedScopeString := "email test3"
// Call buildAuthURL
authURLString := tOidc.buildAuthURL("http://localhost/callback", "test-state", "test-nonce", "")
// Parse the resulting URL
parsedAuthURL, err := url.Parse(authURLString)
if err != nil {
t.Fatalf("Failed to parse generated auth URL %q: %v", authURLString, err)
}
// Get the 'scope' query parameter
actualScopeString := parsedAuthURL.Query().Get("scope")
// Assert that the scope string is as expected
if actualScopeString != expectedScopeString {
t.Errorf("Expected scope parameter to be %q, but got %q. Full URL: %s",
expectedScopeString, actualScopeString, authURLString)
}
// Additionally, ensure 'offline_access' was not added
if strings.Contains(actualScopeString, "offline_access") {
t.Errorf("Scope parameter %q should not contain 'offline_access' when overrideScopes is true and it's not in tOidc.scopes", actualScopeString)
}
}