mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
99881f5837
- Unbounded Replay Cache: Now bounded to 10,000 entries with automatic cleanup - Session Pool Leaks: Proper object lifecycle prevents accumulation - HTTP Client Leaks: Reusable clients eliminate connection overhead - Goroutine Leaks: Tracked lifecycle with graceful shutdown
2809 lines
91 KiB
Go
2809 lines
91 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestSuite holds common test data and setup
|
|
type TestSuite struct {
|
|
t *testing.T
|
|
rsaPrivateKey *rsa.PrivateKey
|
|
rsaPublicKey *rsa.PublicKey
|
|
ecPrivateKey *ecdsa.PrivateKey
|
|
tOidc *TraefikOidc
|
|
mockJWKCache *MockJWKCache
|
|
token string
|
|
sessionManager *SessionManager
|
|
}
|
|
|
|
// Setup initializes the test suite. It generates fresh RSA and EC key pairs,
// publishes the RSA public key (kid "test-key-id") through a MockJWKCache,
// mints ts.token as a valid RS256 token signed with that key, and builds
// ts.tOidc directly (bypassing New) with a real SessionManager and a
// MockTokenExchanger whose code exchange returns ts.token. Any setup failure
// aborts the owning test through ts.t.Fatalf.
func (ts *TestSuite) Setup() {
	ts.t.Helper()

	var err error
	ts.rsaPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		ts.t.Fatalf("Failed to generate RSA key: %v", err)
	}
	ts.rsaPublicKey = &ts.rsaPrivateKey.PublicKey

	// Generate EC key for EC key tests
	ts.ecPrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		ts.t.Fatalf("Failed to generate EC key: %v", err)
	}

	// Publish the RSA public key as a JWK under the same kid the test tokens use.
	jwk := JWK{
		Kty: "RSA",
		Kid: "test-key-id",
		Alg: "RS256",
		N:   base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
		E:   base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
	}
	jwks := &JWKSet{
		Keys: []JWK{jwk},
	}

	// Create a mock JWKCache
	ts.mockJWKCache = &MockJWKCache{
		JWKS: jwks,
		Err:  nil,
	}

	// Create a test JWT token signed with the RSA private key. iat/nbf are
	// backdated two minutes to account for clock skew.
	now := time.Now()
	exp := now.Add(1 * time.Hour).Unix()
	iat := now.Add(-2 * time.Minute).Unix()
	nbf := now.Add(-2 * time.Minute).Unix()

	ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
		"iss":   "https://test-issuer.com",
		"aud":   "test-client-id",
		"exp":   exp,
		"iat":   iat,
		"nbf":   nbf,
		"sub":   "test-subject",
		"email": "user@example.com",
		"nonce": "test-nonce",
		"jti":   generateRandomString(16),
	})
	if err != nil {
		ts.t.Fatalf("Failed to create test JWT: %v", err)
	}

	logger := NewLogger("info")

	// A nil session manager would only surface later as a confusing panic deep
	// inside ServeHTTP, so fail setup immediately instead of ignoring the error.
	ts.sessionManager, err = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
	if err != nil {
		ts.t.Fatalf("Failed to create session manager: %v", err)
	}

	// Common TraefikOidc instance
	ts.tOidc = &TraefikOidc{
		issuerURL:          "https://test-issuer.com",
		clientID:           "test-client-id",
		clientSecret:       "test-client-secret",
		jwkCache:           ts.mockJWKCache,
		jwksURL:            "https://test-jwks-url.com",
		revocationURL:      "https://revocation-endpoint.com",
		limiter:            rate.NewLimiter(rate.Every(time.Second), 10),
		tokenBlacklist:     NewCache(), // Use generic cache for blacklist
		tokenCache:         NewTokenCache(),
		logger:             logger,
		allowedUserDomains: map[string]struct{}{"example.com": {}},
		excludedURLs:       map[string]struct{}{"/favicon": {}},
		httpClient:         &http.Client{},
		// Explicitly set paths as New() is bypassed
		redirURLPath:      "/callback",                     // Assume default callback path for tests
		logoutURLPath:     "/callback/logout",              // Assume default logout path for tests
		tokenURL:          "https://test-issuer.com/token", // Explicitly set for refresh tests
		extractClaimsFunc: extractClaims,
		initComplete:      make(chan struct{}),
		sessionManager:    ts.sessionManager,
	}
	// Initialization is considered complete immediately because New() is bypassed.
	close(ts.tOidc.initComplete)

	ts.tOidc.tokenVerifier = ts.tOidc
	ts.tOidc.jwtVerifier = ts.tOidc

	// Set default mock exchanger
	ts.tOidc.tokenExchanger = &MockTokenExchanger{
		ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
			// Default mock behavior for code exchange
			return &TokenResponse{
				IDToken:      ts.token, // Use the valid token from setup
				AccessToken:  ts.token,
				RefreshToken: "default-refresh-token",
				ExpiresIn:    3600,
			}, nil
		},
		RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
			// Default mock behavior for refresh (can be overridden in tests)
			return nil, fmt.Errorf("default mock: refresh not expected")
		},
		RevokeTokenFunc: func(token, tokenType string) error {
			// Default mock behavior for revoke
			return nil
		},
	}
}
|
|
|
|
// The former exchangeCodeForTokenFunc helper was removed when token exchange was refactored behind the TokenExchanger interface; tests now stub exchanges with MockTokenExchanger.
|
|
|
|
// MockJWKCache implements JWKCacheInterface
|
|
type MockJWKCache struct {
|
|
JWKS *JWKSet
|
|
Err error
|
|
}
|
|
|
|
// Close is a no-op for the mock.
|
|
func (m *MockJWKCache) Close() {
|
|
// No operation needed for the mock.
|
|
}
|
|
|
|
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
|
return m.JWKS, m.Err
|
|
}
|
|
|
|
func (m *MockJWKCache) Cleanup() {
|
|
// Mock cleanup implementation
|
|
m.JWKS = nil
|
|
m.Err = nil
|
|
}
|
|
|
|
// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls.
|
|
type MockTokenVerifier struct {
|
|
VerifyFunc func(token string) error
|
|
}
|
|
|
|
func (m *MockTokenVerifier) VerifyToken(token string) error {
|
|
if m.VerifyFunc != nil {
|
|
return m.VerifyFunc(token)
|
|
}
|
|
return fmt.Errorf("VerifyFunc not implemented in mock")
|
|
}
|
|
|
|
// MockTokenExchanger implements TokenExchanger for testing
|
|
type MockTokenExchanger struct {
|
|
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
|
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
|
RevokeTokenFunc func(token, tokenType string) error
|
|
}
|
|
|
|
func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
|
if m.ExchangeCodeFunc != nil {
|
|
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
|
}
|
|
return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock")
|
|
}
|
|
|
|
func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
|
if m.RefreshTokenFunc != nil {
|
|
return m.RefreshTokenFunc(refreshToken)
|
|
}
|
|
return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock")
|
|
}
|
|
|
|
func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
|
if m.RevokeTokenFunc != nil {
|
|
return m.RevokeTokenFunc(token, tokenType)
|
|
}
|
|
return fmt.Errorf("RevokeTokenFunc not implemented in mock")
|
|
}
|
|
|
|
// createTestJWT returns a compact JWS ("header.claims.signature", each part
// base64url-encoded without padding) signed with privateKey.
//
// alg selects both the hash and the RSA signature scheme:
//   - RS256/RS384/RS512 use PKCS #1 v1.5.
//   - PS256/PS384/PS512 use RSASSA-PSS with a salt as long as the hash output,
//     as RFC 7518 §3.5 requires.
//
// The header carries alg, kid and "typ":"JWT". claims are JSON-encoded as given.
//
// An error is returned for an unsupported alg, a nil privateKey, a hash whose
// implementation is not linked into the binary, or a marshalling or signing
// failure.
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
	// Select the appropriate hash function based on algorithm.
	var hashFunc crypto.Hash
	switch alg {
	case "RS256", "PS256":
		hashFunc = crypto.SHA256
	case "RS384", "PS384":
		hashFunc = crypto.SHA384
	case "RS512", "PS512":
		hashFunc = crypto.SHA512
	default:
		return "", fmt.Errorf("unsupported algorithm: %s", alg)
	}
	if privateKey == nil {
		return "", fmt.Errorf("cannot sign %s token: nil private key", alg)
	}
	// crypto.Hash.New panics unless the implementing package (e.g. crypto/sha256)
	// is linked into the binary, so report that as an ordinary error.
	if !hashFunc.Available() {
		return "", fmt.Errorf("hash function for %s is not available", alg)
	}

	headerJSON, err := json.Marshal(map[string]interface{}{
		"alg": alg,
		"kid": kid,
		"typ": "JWT",
	})
	if err != nil {
		return "", fmt.Errorf("marshal JWT header: %w", err)
	}
	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("marshal JWT claims: %w", err)
	}
	signingInput := base64.RawURLEncoding.EncodeToString(headerJSON) + "." + base64.RawURLEncoding.EncodeToString(claimsJSON)

	hasher := hashFunc.New()
	hasher.Write([]byte(signingInput)) // hash.Hash.Write never returns an error
	digest := hasher.Sum(nil)

	var signature []byte
	if strings.HasPrefix(alg, "PS") {
		// Without explicit options SignPSS uses the maximum salt length, which
		// strict JOSE verifiers reject; JWS mandates salt length == hash size.
		signature, err = rsa.SignPSS(rand.Reader, privateKey, hashFunc, digest, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
	} else {
		// Every other alg accepted by the switch above is an RS* algorithm.
		signature, err = rsa.SignPKCS1v15(rand.Reader, privateKey, hashFunc, digest)
	}
	if err != nil {
		return "", fmt.Errorf("sign %s token: %w", alg, err)
	}

	return signingInput + "." + base64.RawURLEncoding.EncodeToString(signature), nil
}
|
|
|
|
// bigIntToBytes returns the absolute value of i as a big-endian byte slice
// with no leading zero bytes. Setup uses it to encode the RSA public exponent
// for the JWK "e" member.
func bigIntToBytes(i *big.Int) []byte {
	magnitude := i.Bytes()
	return magnitude
}
|
|
|
|
// TestVerifyToken exercises TraefikOidc.VerifyToken against five cases:
//   - a valid token,
//   - a token with a corrupted signature,
//   - a blacklisted token,
//   - an exhausted rate limiter,
//   - a token whose claims are already in the token cache.
func TestVerifyToken(t *testing.T) {
	ts := &TestSuite{t: t}
	ts.Setup()

	tests := []struct {
		name          string
		token         string
		blacklist     bool
		rateLimit     bool
		cacheToken    bool
		expectedError bool
	}{
		{
			name:          "Valid token",
			token:         ts.token,
			expectedError: false,
		},
		{
			name:          "Invalid token signature",
			token:         ts.token + "invalid",
			expectedError: true,
		},
		{
			name:          "Blacklisted token",
			token:         ts.token,
			blacklist:     true,
			expectedError: true,
		},
		{
			name:          "Rate limit exceeded",
			token:         ts.token,
			rateLimit:     true,
			expectedError: true,
		},
		{
			name:          "Token in cache",
			token:         ts.token,
			cacheToken:    true,
			expectedError: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Reset token blacklist, cache and limiter so cases cannot leak state into each other.
			ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
			ts.tOidc.tokenCache = NewTokenCache()
			ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)

			if tc.blacklist {
				// Use Set with a duration. Value 'true' is arbitrary.
				ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour)
			}

			if tc.rateLimit {
				// A burst of 0 means the limiter never permits an event.
				ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
			}

			if tc.cacheToken {
				// Use more realistic claims for cached token
				ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
					"iss": "https://test-issuer.com",
					"sub": "test-subject",
					"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
					"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
				}, time.Minute)

				// This case only tests the cache path if the token really is cached.
				// Fail fast here instead of letting VerifyToken fall through to a full
				// verification and fail with an unrelated error.
				claims, exists := ts.tOidc.tokenCache.Get(tc.token)
				if !exists {
					t.Fatalf("Token NOT found in cache despite cacheToken=true")
				}
				t.Logf("Token found in cache with claims: %v", claims)
			}

			err := ts.tOidc.VerifyToken(tc.token)
			if tc.expectedError && err == nil {
				t.Errorf("Test %s: expected error but got nil", tc.name)
			}
			if !tc.expectedError && err != nil {
				t.Errorf("Test %s: expected no error but got %v", tc.name, err)
			}
		})
	}
}
|
|
|
|
// TestServeHTTP tests the ServeHTTP method
|
|
func TestServeHTTP(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
ts.tOidc.next = nextHandler
|
|
ts.tOidc.name = "test"
|
|
|
|
// Helper to create an expired token
|
|
createExpiredToken := func() string {
|
|
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
|
|
iat := time.Now().Add(-2 * time.Hour).Unix()
|
|
nbf := time.Now().Add(-2 * time.Hour).Unix()
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce-expired", // Different nonce for clarity
|
|
"jti": generateRandomString(16),
|
|
})
|
|
return expiredToken
|
|
}
|
|
|
|
// Helper to create a new valid token (simulating refresh)
|
|
createNewValidToken := func() string {
|
|
exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour
|
|
iat := time.Now().Unix()
|
|
nbf := time.Now().Unix()
|
|
newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
// "nonce": "test-nonce-new", // Nonce is typically not included/validated in refreshed tokens
|
|
"jti": generateRandomString(16),
|
|
})
|
|
return newToken
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
requestPath string
|
|
sessionValues map[interface{}]interface{}
|
|
expectedStatus int
|
|
expectedBody string
|
|
setupSession func(*SessionData)
|
|
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
|
|
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
|
|
requestHeaders map[string]string // Added for setting headers like Accept
|
|
}{
|
|
{
|
|
name: "Excluded URL",
|
|
requestPath: "/favicon.ico",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Unauthenticated request (no refresh token) to protected URL",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Ensure no tokens are set
|
|
session.SetAuthenticated(false)
|
|
session.SetAccessToken("")
|
|
session.SetRefreshToken("")
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
|
|
},
|
|
{
|
|
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(false) // Not authenticated
|
|
session.SetAccessToken("") // No access token
|
|
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token-for-unauth-test" {
|
|
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := createNewValidToken() // Use helper from TestServeHTTP
|
|
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect OK after successful refresh
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(false) // Not authenticated
|
|
session.SetAccessToken("") // No access token
|
|
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid")
|
|
}
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
|
|
},
|
|
{
|
|
name: "Authenticated request to protected URL (Valid Token)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
// Generate a fresh valid token for this test case to avoid replay issues
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: "OK",
|
|
},
|
|
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
|
|
{
|
|
name: "Authenticated request with expired token and successful refresh",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
|
|
// even if session.SetAuthenticated(true) was called.
|
|
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
|
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
|
session.SetEmail("user@example.com")
|
|
session.SetAccessToken(createExpiredToken()) // Set expired token
|
|
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token" {
|
|
return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := createNewValidToken()
|
|
return &TokenResponse{
|
|
IDToken: newToken, // Return new valid token
|
|
AccessToken: newToken, // Often the same as ID token in tests
|
|
RefreshToken: "new-refresh-token",
|
|
ExpiresIn: 3600,
|
|
}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success after refresh
|
|
expectedBody: "OK",
|
|
assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) {
|
|
// Create a new request to read the cookies set by the response recorder
|
|
reqForCookieRead := httptest.NewRequest("GET", "/protected", nil)
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
reqForCookieRead.AddCookie(cookie)
|
|
}
|
|
// Get session based on response cookies
|
|
session, err := sessionManager.GetSession(reqForCookieRead)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session after request: %v", err)
|
|
}
|
|
// Assert new tokens are in the session
|
|
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
|
|
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
|
|
}
|
|
if session.GetRefreshToken() != "new-refresh-token" {
|
|
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
|
|
}
|
|
// Also check authenticated flag is now true
|
|
if !session.GetAuthenticated() {
|
|
t.Errorf("Expected session to be marked authenticated after successful refresh")
|
|
}
|
|
},
|
|
},
|
|
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
|
|
{
|
|
name: "Logout URL",
|
|
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect after logout
|
|
expectedBody: "",
|
|
// No specific session assertion needed for logout redirect itself
|
|
},
|
|
{
|
|
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true) // Set flag initially
|
|
session.SetEmail("user@example.com")
|
|
session.SetAccessToken(createExpiredToken()) // Expired access token
|
|
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
|
}
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "application/json",
|
|
},
|
|
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
|
|
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
|
|
},
|
|
// This test case remains valid as the logic should still redirect browser clients on refresh failure
|
|
{
|
|
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true) // Set flag initially
|
|
session.SetEmail("user@example.com")
|
|
session.SetAccessToken(createExpiredToken()) // Expired access token
|
|
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
|
}
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "text/html", // Browser client
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
|
|
},
|
|
// This test case remains valid as proactive refresh should still be attempted
|
|
{
|
|
name: "Authenticated request with token nearing expiry (needs refresh)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Create token expiring soon (e.g., 30s, within default 60s grace period)
|
|
exp := time.Now().Add(30 * time.Second).Unix()
|
|
iat := time.Now().Add(-1 * time.Minute).Unix()
|
|
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
|
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
|
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
session.SetAccessToken(nearExpiryToken)
|
|
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token-for-near-expiry" {
|
|
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := createNewValidToken()
|
|
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success after proactive refresh
|
|
expectedBody: "OK",
|
|
},
|
|
// This test case remains valid as no refresh should be attempted
|
|
{
|
|
name: "Authenticated request with token valid (outside grace period)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
|
|
exp := time.Now().Add(10 * time.Minute).Unix()
|
|
iat := time.Now().Add(-1 * time.Minute).Unix()
|
|
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
|
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
|
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
session.SetAccessToken(validToken)
|
|
session.SetIDToken(validToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("should-not-be-used-refresh-token")
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
// This should NOT be called
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
|
|
return nil, fmt.Errorf("refresh should not have been attempted")
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success, no refresh needed
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Disallowed Domain (Accept: JSON)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "application/json",
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
|
|
},
|
|
{
|
|
name: "Disallowed Domain (Accept: HTML)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "text/html",
|
|
},
|
|
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
|
|
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset token blacklist and cache for each test to prevent token replay detection errors
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
ts.tOidc.tokenCache = NewTokenCache()
|
|
|
|
// Reset the global replayCache to prevent "token replay detected" errors
|
|
replayCacheMu.Lock()
|
|
replayCache = NewCache()
|
|
replayCache.SetMaxSize(10000)
|
|
replayCacheMu.Unlock()
|
|
|
|
// Store original tokenVerifier to restore later
|
|
origTokenVerifier := ts.tOidc.tokenVerifier
|
|
|
|
// Create a mock tokenVerifier that clears the replay cache before verification
|
|
// This prevents replay detection when the same token is verified multiple times within a test
|
|
mockTokenVerifier := &MockTokenVerifier{
|
|
VerifyFunc: func(token string) error {
|
|
// Clear replay cache before token verification
|
|
replayCacheMu.Lock()
|
|
replayCache = NewCache()
|
|
replayCache.SetMaxSize(10000)
|
|
replayCacheMu.Unlock()
|
|
|
|
// Call the original verifier's VerifyToken method
|
|
// Ensure origTokenVerifier is not nil and is the correct type if necessary,
|
|
// though in this context it should be the *TraefikOidc instance.
|
|
if origTokenVerifier != nil {
|
|
return origTokenVerifier.VerifyToken(token)
|
|
}
|
|
return fmt.Errorf("original token verifier is nil")
|
|
},
|
|
}
|
|
|
|
// Replace tokenVerifier with our mock
|
|
ts.tOidc.tokenVerifier = mockTokenVerifier
|
|
|
|
// Restore original tokenVerifier after test
|
|
defer func() {
|
|
ts.tOidc.tokenVerifier = origTokenVerifier
|
|
}()
|
|
|
|
req := httptest.NewRequest("GET", tc.requestPath, nil)
|
|
// Set common headers needed by the logic (determineScheme, determineHost)
|
|
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
|
|
req.Header.Set("X-Forwarded-Host", "testhost.com")
|
|
req.Host = "testhost.com" // Also set Host header
|
|
// Set request headers from test case
|
|
if tc.requestHeaders != nil {
|
|
for key, value := range tc.requestHeaders {
|
|
req.Header.Set(key, value)
|
|
}
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Setup session if needed
|
|
session, err := ts.tOidc.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err)
|
|
}
|
|
if tc.setupSession != nil {
|
|
tc.setupSession(session)
|
|
// Save session to recorder to get cookies
|
|
saveRecorder := httptest.NewRecorder()
|
|
if err := session.Save(req, saveRecorder); err != nil {
|
|
t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err)
|
|
}
|
|
// Copy cookies from save recorder to the actual request
|
|
for _, cookie := range saveRecorder.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
}
|
|
|
|
// Mocking setup for TokenExchanger
|
|
originalExchanger := ts.tOidc.tokenExchanger // Store original
|
|
mockExchanger, isMock := originalExchanger.(*MockTokenExchanger)
|
|
if !isMock {
|
|
// This case should ideally not happen if Setup correctly assigns the mock,
|
|
// but handle it defensively.
|
|
t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.")
|
|
mockExchanger = &MockTokenExchanger{
|
|
ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken,
|
|
RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken,
|
|
RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider,
|
|
}
|
|
ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock
|
|
}
|
|
|
|
// Override specific mock methods if needed for the test case
|
|
originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func
|
|
if tc.mockRefreshTokenFunc != nil {
|
|
// Assign the test case specific mock function
|
|
mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken)
|
|
}
|
|
|
|
// Call ServeHTTP
|
|
ts.tOidc.ServeHTTP(rr, req)
|
|
|
|
// Restore original exchanger and mock function state
|
|
ts.tOidc.tokenExchanger = originalExchanger
|
|
if tc.mockRefreshTokenFunc != nil && mockExchanger != nil {
|
|
// Restore the previous mock function if we overrode it
|
|
mockExchanger.RefreshTokenFunc = originalMockRefreshFunc
|
|
}
|
|
|
|
// Check response status
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String())
|
|
}
|
|
|
|
// Check response body if expected
|
|
// Check response body if expected (handle JSON vs HTML)
|
|
if tc.expectedBody != "" {
|
|
// For JSON, compare directly
|
|
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
|
|
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
|
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
|
|
}
|
|
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
|
|
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
|
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
|
|
}
|
|
}
|
|
// Add more sophisticated HTML body checks if needed
|
|
}
|
|
|
|
// Perform post-request session assertions if defined
|
|
if tc.assertSessionAfterRequest != nil {
|
|
tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJWKToPEM(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
jwk *JWK
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "Unsupported Key Type",
|
|
jwk: &JWK{
|
|
Kty: "unsupported",
|
|
Kid: "test-key-id",
|
|
},
|
|
expectError: true,
|
|
errorContains: "unsupported key type",
|
|
},
|
|
{
|
|
name: "EC Key",
|
|
jwk: &JWK{
|
|
Kty: "EC",
|
|
Kid: "test-ec-key-id",
|
|
Crv: "P-256",
|
|
X: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.X.Bytes()),
|
|
Y: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.Y.Bytes()),
|
|
},
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
pemBytes, err := jwkToPEM(tc.jwk)
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if len(pemBytes) == 0 {
|
|
t.Error("PEM bytes should not be empty")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseJWT(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "Invalid Format",
|
|
token: "invalid.jwt.token",
|
|
expectError: true,
|
|
errorContains: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Valid Token",
|
|
token: ts.token,
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := parseJWT(tc.token)
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJWTVerify_MissingClaims(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
jwt := &JWT{
|
|
Header: map[string]interface{}{
|
|
"alg": "RS256",
|
|
"kid": "test-key-id",
|
|
},
|
|
Claims: map[string]interface{}{
|
|
// Missing 'iss', 'aud', 'exp', 'iat', 'sub'
|
|
},
|
|
}
|
|
|
|
err := jwt.Verify("https://test-issuer.com", "test-client-id")
|
|
if err == nil {
|
|
t.Error("Expected error for missing claims, got nil")
|
|
}
|
|
}
|
|
|
|
func TestHandleCallback(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
redirectURL := "http://example.com/"
|
|
|
|
tests := []struct {
|
|
name string
|
|
queryParams string
|
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
|
sessionSetupFunc func(*SessionData)
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Success",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusFound,
|
|
},
|
|
{
|
|
name: "Missing Code",
|
|
queryParams: "",
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Exchange Code Error",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return nil, fmt.Errorf("exchange code error")
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing ID Token",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Disallowed Email",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Generate a unique token for this test case to avoid replay issues
|
|
// Use claims relevant to this test (disallowed email)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject-disallowed",
|
|
"email": "user@disallowed.com", // The disallowed email for this test
|
|
"nonce": "test-nonce", // Match the nonce set in sessionSetupFunc
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create disallowed token for test: %w", err)
|
|
}
|
|
return &TokenResponse{
|
|
IDToken: disallowedToken,
|
|
RefreshToken: "test-refresh-token-disallowed",
|
|
}, nil
|
|
},
|
|
// Remove mock extractClaimsFunc - let the real one parse the disallowedToken
|
|
// The test should still fail correctly on the email check later.
|
|
// extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// return map[string]interface{}{
|
|
// "email": "user@disallowed.com",
|
|
// "nonce": "test-nonce",
|
|
// }, nil
|
|
// },
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "Invalid State Parameter",
|
|
queryParams: "?code=test-code&state=invalid-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Nonce Mismatch",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing Nonce in Claims",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
// Missing nonce
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc // Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Clear the global replay cache before each test run
|
|
replayCacheMu.Lock()
|
|
replayCache = NewCache()
|
|
replayCache.SetMaxSize(10000)
|
|
replayCacheMu.Unlock()
|
|
|
|
// Explicitly clear the shared blacklist at the start of each sub-test
|
|
// to ensure no state leaks, even though we expect the local one to be used.
|
|
// Note: This line might be redundant now that the verifier is local, but keep for safety.
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
|
|
|
// Create a new instance for each test to avoid state carryover
|
|
instanceExtractClaimsFunc := tc.extractClaimsFunc
|
|
if instanceExtractClaimsFunc == nil {
|
|
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
|
|
}
|
|
tOidc := &TraefikOidc{
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
logger: logger,
|
|
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
|
|
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
|
|
tokenVerifier: nil, // Will be set to self below
|
|
jwtVerifier: nil, // Temporarily nil, will be set below
|
|
sessionManager: sessionManager,
|
|
tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run
|
|
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
|
// Wrap the test case function to match the required signature
|
|
if tc.exchangeCodeForToken != nil {
|
|
// Only call if the test case provided a function
|
|
return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier)
|
|
}
|
|
// Provide a default behavior or error if no mock was provided for this test case
|
|
return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case")
|
|
},
|
|
// Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback
|
|
},
|
|
tokenCache: NewTokenCache(), // Initialize token cache
|
|
limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter
|
|
tokenBlacklist: NewCache(), // Initialize token blacklist cache
|
|
|
|
// Add potentially missing fields based on New() comparison
|
|
clientID: ts.tOidc.clientID,
|
|
issuerURL: ts.tOidc.issuerURL,
|
|
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
|
|
httpClient: ts.tOidc.httpClient,
|
|
initComplete: make(chan struct{}), // Initialize the channel
|
|
// Setting other fields like paths, enablePKCE etc. if needed
|
|
}
|
|
tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW
|
|
tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW
|
|
close(tOidc.initComplete) // Mark this test instance as initialized
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Create session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
if tc.sessionSetupFunc != nil {
|
|
tc.sessionSetupFunc(session)
|
|
}
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder for the actual test
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Call handleCallback
|
|
tOidc.handleCallback(rr, req, redirectURL)
|
|
|
|
// Check response
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsAllowedDomain(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
allowedDomains map[string]struct{}
|
|
allowedUsers map[string]struct{}
|
|
allowed bool
|
|
expectedLogOutput string // For testing log messages
|
|
}{
|
|
{
|
|
name: "Allowed domain",
|
|
email: "user@example.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Disallowed domain",
|
|
email: "user@notallowed.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Invalid email",
|
|
email: "invalid-email",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Specific user is allowed regardless of domain",
|
|
email: "specific.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Case-insensitive email matching for specific user",
|
|
email: "Specific.User@otherdomain.com", // Mixed case
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, // Lowercase
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Only allowed users configured (no domains)",
|
|
email: "specific.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{}, // Empty allowed domains
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "User not in allowed list when only specific users configured",
|
|
email: "other.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{}, // Empty allowed domains
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "No restrictions (both empty)",
|
|
email: "anyone@anydomain.com",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure TraefikOidc instance for this test case
|
|
tOidc := ts.tOidc
|
|
tOidc.allowedUserDomains = tc.allowedDomains
|
|
tOidc.allowedUsers = tc.allowedUsers
|
|
|
|
allowed := tOidc.isAllowedDomain(tc.email)
|
|
if allowed != tc.allowed {
|
|
t.Errorf("Expected allowed=%v, got %v", tc.allowed, allowed)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOIDCHandler(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
ts.token = "valid.jwt.token"
|
|
|
|
tests := []struct {
|
|
name string
|
|
queryParams string
|
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
|
sessionSetupFunc func(session *sessions.Session)
|
|
expectedStatus int
|
|
blacklist bool
|
|
rateLimit bool
|
|
cacheToken bool
|
|
}{
|
|
{
|
|
name: "Missing Code",
|
|
queryParams: "",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims with invalid nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing Nonce in Claims",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims without nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Invalid State Parameter",
|
|
queryParams: "?code=test-code&state=invalid-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Nonce Mismatch",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims with mismatched nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc // Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset token blacklist and cache
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
ts.tOidc.tokenCache = NewTokenCache()
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
|
|
|
// Set up the test case
|
|
if tc.blacklist {
|
|
// Use Set with a duration. Value 'true' is arbitrary.
|
|
ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour)
|
|
}
|
|
|
|
if tc.rateLimit {
|
|
// Exceed rate limit
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
|
|
}
|
|
|
|
if tc.cacheToken {
|
|
// Cache the token with dummy claims
|
|
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
|
|
"empty": "claim",
|
|
}, 60)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHandleLogout tests the logout functionality
|
|
func TestHandleLogout(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
// Create mock revocation endpoint server
|
|
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
// Verify the required parameters are present
|
|
if r.Form.Get("token") == "" {
|
|
t.Error("Missing token parameter")
|
|
}
|
|
if r.Form.Get("token_type_hint") == "" {
|
|
t.Error("Missing token_type_hint parameter")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mockRevocationServer.Close()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupSession func(*SessionData)
|
|
endSessionURL string
|
|
expectedStatus int
|
|
expectedURL string
|
|
host string
|
|
}{
|
|
{
|
|
name: "Successful logout with end session endpoint",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken("test.id.token")
|
|
session.SetRefreshToken("test-refresh-token")
|
|
},
|
|
endSessionURL: "https://provider/end-session",
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Successful logout without end session endpoint",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken("test.id.token")
|
|
session.SetRefreshToken("test-refresh-token")
|
|
},
|
|
endSessionURL: "",
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "http://example.com/",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Logout with empty session",
|
|
setupSession: func(session *SessionData) {},
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "http://example.com/",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Logout with invalid end session URL",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken("test.id.token")
|
|
session.SetRefreshToken("test-refresh-token")
|
|
},
|
|
endSessionURL: ":\\invalid-url",
|
|
expectedStatus: http.StatusInternalServerError,
|
|
host: "test-host",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
|
tOidc := &TraefikOidc{
|
|
revocationURL: mockRevocationServer.URL,
|
|
endSessionURL: tc.endSessionURL,
|
|
scheme: "http",
|
|
logger: logger,
|
|
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
|
httpClient: &http.Client{},
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
tokenCache: NewTokenCache(),
|
|
forceHTTPS: false,
|
|
sessionManager: sessionManager,
|
|
}
|
|
|
|
// Create request with proper headers
|
|
req := httptest.NewRequest("GET", "/logout", nil)
|
|
req.Header.Set("Host", tc.host)
|
|
|
|
// Create a response recorder
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Get a session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
if tc.setupSession != nil {
|
|
tc.setupSession(session)
|
|
}
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Handle logout
|
|
tOidc.handleLogout(rr, req)
|
|
|
|
// Check response
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
|
|
if tc.expectedURL != "" {
|
|
location := rr.Header().Get("Location")
|
|
if location != tc.expectedURL {
|
|
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
|
|
}
|
|
}
|
|
|
|
// Verify session is cleared
|
|
updatedSession, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get updated session: %v", err)
|
|
}
|
|
|
|
// Verify tokens are cleared
|
|
if token := updatedSession.GetAccessToken(); token != "" {
|
|
t.Error("Access token not cleared")
|
|
}
|
|
if token := updatedSession.GetRefreshToken(); token != "" {
|
|
t.Error("Refresh token not cleared")
|
|
}
|
|
if updatedSession.GetAuthenticated() {
|
|
t.Error("Session still marked as authenticated")
|
|
}
|
|
|
|
// Check token blacklist
|
|
if token := session.GetAccessToken(); token != "" {
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Access token was not blacklisted in cache")
|
|
}
|
|
}
|
|
if token := session.GetRefreshToken(); token != "" {
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Refresh token was not blacklisted in cache")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRevokeTokenWithProvider tests the token revocation with provider
|
|
func TestRevokeTokenWithProvider(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
tokenType string
|
|
statusCode int
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Successful token revocation",
|
|
token: "valid-token",
|
|
tokenType: "refresh_token",
|
|
statusCode: http.StatusOK,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Failed token revocation",
|
|
token: "invalid-token",
|
|
tokenType: "refresh_token",
|
|
statusCode: http.StatusBadRequest,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create test server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request method and content type
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
|
|
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
|
|
}
|
|
|
|
// Verify form values
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
if got := r.Form.Get("token"); got != tc.token {
|
|
t.Errorf("Expected token %s, got %s", tc.token, got)
|
|
}
|
|
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
|
|
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
|
|
}
|
|
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
|
|
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
|
|
}
|
|
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
|
|
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
|
|
}
|
|
|
|
w.WriteHeader(tc.statusCode)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Set revocation URL to test server
|
|
ts.tOidc.revocationURL = server.URL
|
|
|
|
// Test token revocation
|
|
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
|
|
if tc.expectError && err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
if !tc.expectError && err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRevokeToken tests the token revocation functionality
|
|
func TestRevokeToken(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
token := "test.token.with.claims"
|
|
claims := map[string]interface{}{
|
|
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
|
}
|
|
|
|
// Test token revocation
|
|
t.Run("Token revocation", func(t *testing.T) {
|
|
// Create a new instance for this specific test
|
|
tOidc := &TraefikOidc{
|
|
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
|
tokenCache: NewTokenCache(),
|
|
logger: NewLogger("info"), // Initialize the logger
|
|
}
|
|
|
|
// Cache the token
|
|
tOidc.tokenCache.Set(token, claims, time.Hour)
|
|
|
|
// Revoke the token
|
|
tOidc.RevokeToken(token)
|
|
|
|
// Verify token was removed from cache
|
|
if _, exists := tOidc.tokenCache.Get(token); exists {
|
|
t.Error("Token was not removed from cache")
|
|
}
|
|
|
|
// Verify token was added to blacklist cache
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Token was not added to blacklist")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Add this new test function
|
|
func TestBuildLogoutURL(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
endSessionURL string
|
|
idToken string
|
|
postLogoutRedirect string
|
|
expectedURL string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Valid URL",
|
|
endSessionURL: "https://provider/end-session",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Invalid URL",
|
|
endSessionURL: "://invalid-url",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "URL with existing query parameters",
|
|
endSessionURL: "https://provider/end-session?existing=param",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if url != tc.expectedURL {
|
|
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHandleExpiredToken exercises handleExpiredToken on sessions that hold an expired access token.
|
|
func TestHandleExpiredToken(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupSession func(*SessionData)
|
|
expectedPath string
|
|
}{
|
|
{
|
|
name: "Basic expired token",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken("expired.token")
|
|
session.SetEmail("test@example.com")
|
|
},
|
|
expectedPath: "/original/path",
|
|
},
|
|
{
|
|
name: "Session with additional values",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken("expired.token")
|
|
session.mainSession.Values["custom_value"] = "should-be-cleared"
|
|
},
|
|
expectedPath: "/another/path",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
|
|
|
tOidc := &TraefikOidc{
|
|
sessionManager: sessionManager,
|
|
logger: logger,
|
|
tokenVerifier: ts.tOidc.tokenVerifier,
|
|
jwtVerifier: ts.tOidc.jwtVerifier,
|
|
initComplete: make(chan struct{}),
|
|
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
http.Redirect(rw, req, "/login", http.StatusFound)
|
|
},
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", tc.expectedPath, nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Get session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Setup session data
|
|
tc.setupSession(session)
|
|
|
|
// Handle expired token
|
|
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
|
|
|
|
// Get the updated session to verify changes
|
|
updatedSession, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get updated session: %v", err)
|
|
}
|
|
|
|
// Verify main session values
|
|
if updatedSession.GetCSRF() == "" {
|
|
t.Error("CSRF token not set")
|
|
}
|
|
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
|
|
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
|
|
}
|
|
if updatedSession.GetNonce() == "" {
|
|
t.Error("Nonce not set")
|
|
}
|
|
|
|
// Verify tokens are cleared
|
|
if token := updatedSession.GetAccessToken(); token != "" {
|
|
t.Error("Access token not cleared")
|
|
}
|
|
if token := updatedSession.GetRefreshToken(); token != "" {
|
|
t.Error("Refresh token not cleared")
|
|
}
|
|
|
|
// Verify redirect status
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Add this new test function
|
|
func TestExtractGroupsAndRoles(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claims map[string]interface{}
|
|
expectGroups []string
|
|
expectRoles []string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Valid groups and roles",
|
|
claims: map[string]interface{}{
|
|
"groups": []interface{}{"group1", "group2"},
|
|
"roles": []interface{}{"role1", "role2"},
|
|
},
|
|
expectGroups: []string{"group1", "group2"},
|
|
expectRoles: []string{"role1", "role2"},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Empty groups and roles",
|
|
claims: map[string]interface{}{
|
|
"groups": []interface{}{},
|
|
"roles": []interface{}{},
|
|
},
|
|
expectGroups: []string{},
|
|
expectRoles: []string{},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Invalid groups format",
|
|
claims: map[string]interface{}{
|
|
"groups": "not-an-array",
|
|
"roles": []interface{}{"role1"},
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create a test token with the claims
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
|
|
// Compare groups
|
|
if !stringSliceEqual(groups, tc.expectGroups) {
|
|
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
|
|
}
|
|
|
|
// Compare roles
|
|
if !stringSliceEqual(roles, tc.expectRoles) {
|
|
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
|
|
// can be created and initialized properly for different routes
|
|
func TestMultipleMiddlewareInstances(t *testing.T) {
|
|
// Create mock provider metadata server
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
metadata := ProviderMetadata{
|
|
Issuer: "https://test-issuer.com",
|
|
AuthURL: "https://test-issuer.com/auth",
|
|
TokenURL: "https://test-issuer.com/token",
|
|
JWKSURL: "https://test-issuer.com/jwks",
|
|
RevokeURL: "https://test-issuer.com/revoke",
|
|
EndSessionURL: "https://test-issuer.com/end-session",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Create base config
|
|
config := &Config{
|
|
ProviderURL: mockServer.URL,
|
|
ClientID: "test-client",
|
|
ClientSecret: "test-secret",
|
|
CallbackURL: "/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
}
|
|
|
|
// Create multiple middleware instances
|
|
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
|
|
var middlewares []*TraefikOidc
|
|
|
|
for _, route := range routes {
|
|
config.CallbackURL = route + "/callback"
|
|
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config, "test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
|
|
}
|
|
|
|
// Type assert to access internal fields
|
|
if m, ok := middleware.(*TraefikOidc); ok {
|
|
middlewares = append(middlewares, m)
|
|
} else {
|
|
t.Fatalf("Middleware is not of type *TraefikOidc")
|
|
}
|
|
}
|
|
|
|
// Wait for all instances to initialize
|
|
for i, m := range middlewares {
|
|
select {
|
|
case <-m.initComplete:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Middleware instance %d failed to initialize", i)
|
|
}
|
|
|
|
// Verify each instance has its own unique configuration
|
|
if m.issuerURL != "https://test-issuer.com" {
|
|
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
|
|
}
|
|
if m.authURL != "https://test-issuer.com/auth" {
|
|
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
|
|
}
|
|
if m.tokenURL != "https://test-issuer.com/token" {
|
|
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
|
|
}
|
|
if m.jwksURL != "https://test-issuer.com/jwks" {
|
|
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
|
|
}
|
|
if m.redirURLPath != routes[i]+"/callback" {
|
|
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
|
}
|
|
}
|
|
|
|
// Test that each instance can handle requests independently
|
|
for i, m := range middlewares {
|
|
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rr, req)
|
|
|
|
// Should redirect to auth URL since not authenticated
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
|
}
|
|
|
|
location := rr.Header().Get("Location")
|
|
if !strings.Contains(location, "https://test-issuer.com/auth") {
|
|
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
// Create consistent timestamps for all test cases
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
|
|
tests := []struct {
|
|
name string
|
|
allowedRolesAndGroups map[string]struct{}
|
|
claims map[string]interface{}
|
|
setupSession func(*SessionData)
|
|
expectedStatus int
|
|
expectedHeaders map[string]string
|
|
}{
|
|
{
|
|
name: "User with allowed role",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"admin": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"admin", "user"},
|
|
"groups": []interface{}{"group1"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "admin,user",
|
|
"X-User-Groups": "group1",
|
|
},
|
|
},
|
|
{
|
|
name: "User with allowed group",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"allowed-group": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"allowed-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "user",
|
|
"X-User-Groups": "allowed-group",
|
|
},
|
|
},
|
|
{
|
|
name: "User without allowed roles or groups",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"admin": {},
|
|
"allowed-group": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"regular-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "No role/group restrictions",
|
|
allowedRolesAndGroups: map[string]struct{}{},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"regular-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "user",
|
|
"X-User-Groups": "regular-group",
|
|
},
|
|
},
|
|
{
|
|
name: "Claims without roles and groups",
|
|
allowedRolesAndGroups: map[string]struct{}{},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create token with claims
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
// Create test handler
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Configure OIDC middleware
|
|
tOidc := ts.tOidc
|
|
tOidc.next = nextHandler
|
|
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Set up session
|
|
session, err := tOidc.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
tc.setupSession(session)
|
|
session.SetAccessToken(token)
|
|
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Serve request
|
|
tOidc.ServeHTTP(rr, req)
|
|
|
|
// Check status code
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
|
|
// Check headers if status is OK
|
|
if tc.expectedStatus == http.StatusOK {
|
|
for header, expectedValue := range tc.expectedHeaders {
|
|
if value := req.Header.Get(header); value != expectedValue {
|
|
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// stringSliceEqual reports whether a and b have the same length and hold the
// same strings in the same order. Only lengths are compared for emptiness, so
// a nil slice equals an empty one.
func stringSliceEqual(a, b []string) bool {
	equal := len(a) == len(b)
	for i := 0; equal && i < len(a); i++ {
		equal = a[i] == b[i]
	}
	return equal
}
|
|
|
|
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
|
func TestExchangeTokensWithRedirects(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupServer func() *httptest.Server
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "Successful token exchange with redirects",
|
|
setupServer: func() *httptest.Server {
|
|
redirectCount := 0
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if redirectCount < 3 {
|
|
// Set a cookie before redirecting
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
|
|
Value: "test-value",
|
|
})
|
|
redirectCount++
|
|
w.Header().Set("Location", r.URL.String())
|
|
w.WriteHeader(http.StatusFound)
|
|
return
|
|
}
|
|
|
|
// Verify all cookies from previous redirects are present
|
|
cookies := r.Cookies()
|
|
if len(cookies) != 3 {
|
|
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
|
}
|
|
for i := 0; i < 3; i++ {
|
|
found := false
|
|
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == expectedName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Cookie %s not found", expectedName)
|
|
}
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Too many redirects",
|
|
setupServer: func() *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Location", r.URL.String())
|
|
w.WriteHeader(http.StatusFound)
|
|
}))
|
|
},
|
|
expectError: true,
|
|
errorContains: "stopped after 50 redirects",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := tc.setupServer()
|
|
defer server.Close()
|
|
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.tokenURL = server.URL
|
|
|
|
// Test token exchange
|
|
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if response == nil {
|
|
t.Error("Expected token response but got nil")
|
|
} else if response.IDToken != "test.id.token" {
|
|
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
|
func TestBuildAuthURL(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
authURL string
|
|
issuerURL string
|
|
redirectURL string
|
|
state string
|
|
nonce string
|
|
enablePKCE bool
|
|
codeChallenge string
|
|
expectedPrefix string
|
|
checkPKCE bool
|
|
}{
|
|
{
|
|
name: "Absolute Auth URL",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "Relative Auth URL",
|
|
authURL: "/oidc/auth",
|
|
issuerURL: "https://logto.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "Relative Auth URL with Different Issuer",
|
|
authURL: "/sign-in",
|
|
issuerURL: "https://auth.example.com:8443",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "With PKCE Enabled",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: true,
|
|
codeChallenge: "test-code-challenge",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: true,
|
|
},
|
|
{
|
|
name: "With PKCE Enabled but No Challenge",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: true,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "With PKCE Disabled but Challenge Provided",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "test-code-challenge",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.authURL = tc.authURL
|
|
tOidc.issuerURL = tc.issuerURL
|
|
tOidc.enablePKCE = tc.enablePKCE
|
|
|
|
// Call buildAuthURL with code challenge
|
|
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
|
|
|
|
// Verify the URL starts with the expected prefix
|
|
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
|
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
|
|
}
|
|
|
|
// Parse the resulting URL to verify query parameters
|
|
parsedURL, err := url.Parse(result)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse resulting URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
expectedParams := map[string]string{
|
|
"client_id": tOidc.clientID,
|
|
"response_type": "code",
|
|
"redirect_uri": tc.redirectURL,
|
|
"state": tc.state,
|
|
"nonce": tc.nonce,
|
|
}
|
|
|
|
for key, expected := range expectedParams {
|
|
if got := query.Get(key); got != expected {
|
|
t.Errorf("Expected %s=%q, got %q", key, expected, got)
|
|
}
|
|
}
|
|
|
|
// Verify PKCE parameters
|
|
if tc.checkPKCE {
|
|
if got := query.Get("code_challenge"); got != tc.codeChallenge {
|
|
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
|
|
}
|
|
if got := query.Get("code_challenge_method"); got != "S256" {
|
|
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
|
|
}
|
|
} else {
|
|
if got := query.Get("code_challenge"); got != "" {
|
|
t.Errorf("Expected no code_challenge, but got %q", got)
|
|
}
|
|
if got := query.Get("code_challenge_method"); got != "" {
|
|
t.Errorf("Expected no code_challenge_method, but got %q", got)
|
|
}
|
|
}
|
|
|
|
// Verify scopes are present and correct
|
|
if len(tOidc.scopes) > 0 {
|
|
expectedScopes := strings.Join(tOidc.scopes, " ")
|
|
if got := query.Get("scope"); got != expectedScopes {
|
|
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
|
|
func TestExchangeCodeForToken(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
enablePKCE bool
|
|
codeVerifier string
|
|
setupMock func(t *testing.T) *httptest.Server
|
|
}{
|
|
{
|
|
name: "With PKCE Enabled and Code Verifier",
|
|
enablePKCE: true,
|
|
codeVerifier: "test-code-verifier",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
|
|
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
{
|
|
name: "With PKCE Disabled but Code Verifier Provided",
|
|
enablePKCE: false,
|
|
codeVerifier: "test-code-verifier",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is NOT included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
{
|
|
name: "With PKCE Enabled but No Code Verifier",
|
|
enablePKCE: true,
|
|
codeVerifier: "",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is NOT included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := tc.setupMock(t)
|
|
defer server.Close()
|
|
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.tokenURL = server.URL
|
|
tOidc.enablePKCE = tc.enablePKCE
|
|
|
|
// Test exchangeCodeForToken
|
|
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if response == nil {
|
|
t.Error("Expected token response but got nil")
|
|
} else if response.IDToken != "test.id.token" {
|
|
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
|
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
|
ts := &TestSuite{t: t}
|
|
ts.Setup()
|
|
|
|
// Create a request with query parameters
|
|
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
|
responseRecorder := httptest.NewRecorder()
|
|
|
|
// Get session
|
|
session, err := ts.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Call defaultInitiateAuthentication
|
|
redirectURL := "http://example.com/callback"
|
|
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
|
|
|
|
// Verify that the incoming path includes query parameters
|
|
incomingPath := session.GetIncomingPath()
|
|
expectedPath := "/protected/resource?param1=value1¶m2=value2"
|
|
if incomingPath != expectedPath {
|
|
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
|
}
|
|
}
|
|
|
|
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
|
|
func TestVerifyTimeConstraint(t *testing.T) {
|
|
// Define tolerances used in jwt.go (ensure they match)
|
|
toleranceFuture := 2 * time.Minute
|
|
tolerancePast := 10 * time.Second
|
|
|
|
now := time.Now()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claimTime time.Time
|
|
claimName string
|
|
futureCheck bool // true for exp, false for iat/nbf
|
|
expectError bool
|
|
}{
|
|
// Expiration (future=true, tolerance=2min)
|
|
{
|
|
name: "EXP: Valid (expires in 1 min)",
|
|
claimTime: now.Add(1 * time.Minute),
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "EXP: Expired (expired 3 min ago)",
|
|
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
|
|
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: false, // Should be allowed due to future tolerance
|
|
},
|
|
|
|
// Issued At (future=false, tolerance=10s)
|
|
{
|
|
name: "IAT: Valid (issued 1 min ago)",
|
|
claimTime: now.Add(-1 * time.Minute),
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "IAT: Invalid (issued 15 sec in future)",
|
|
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: true, // "token used before issued"
|
|
},
|
|
{
|
|
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
|
|
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: false, // Should be allowed due to past tolerance
|
|
},
|
|
|
|
// Not Before (future=false, tolerance=10s)
|
|
{
|
|
name: "NBF: Valid (active 1 min ago)",
|
|
claimTime: now.Add(-1 * time.Minute),
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "NBF: Invalid (active in 15 sec)",
|
|
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: true, // "token not yet valid"
|
|
},
|
|
{
|
|
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
|
|
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: false, // Should be allowed due to past tolerance
|
|
},
|
|
}
|
|
|
|
// Temporarily adjust global tolerances for test consistency, then restore
|
|
originalFutureTolerance := ClockSkewToleranceFuture
|
|
originalPastTolerance := ClockSkewTolerancePast
|
|
ClockSkewToleranceFuture = toleranceFuture
|
|
ClockSkewTolerancePast = tolerancePast
|
|
defer func() {
|
|
ClockSkewToleranceFuture = originalFutureTolerance
|
|
ClockSkewTolerancePast = originalPastTolerance
|
|
}()
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Convert claim time to float64 unix timestamp
|
|
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
|
|
|
|
var err error
|
|
// Call the specific verification function which uses verifyTimeConstraint
|
|
if tc.claimName == "exp" {
|
|
err = verifyExpiration(unixTime)
|
|
} else if tc.claimName == "iat" {
|
|
err = verifyIssuedAt(unixTime)
|
|
} else if tc.claimName == "nbf" {
|
|
err = verifyNotBefore(unixTime)
|
|
} else {
|
|
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
|
|
}
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
|
|
} else {
|
|
t.Logf("Got expected error: %v", err) // Log the error for confirmation
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
} // Add missing closing brace for TestVerifyTimeConstraint
|