Compare commits

..

2 Commits

13 changed files with 484 additions and 3097 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ testData:
- raczylo.com
allowedRolesAndGroups:
- guest-endpoints
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
sessionEncryptionKey: potato-secret
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
-2
View File
@@ -19,8 +19,6 @@ Middleware currently supports following scenarios:
#### How to configure...
* `sessionEncryptionKey` should be at least 32 bytes long.
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
+13 -116
View File
@@ -5,168 +5,65 @@ import (
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
// CacheItem represents an item in the cache
type CacheItem struct {
// Value is the cached data of any type
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
Value interface{}
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
// Cache is a simple in-memory cache
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache
maxSize int
// accessList maintains the order of item access for eviction
accessList []string
}
// DefaultMaxSize is the default maximum number of items in the cache
const DefaultMaxSize = 1000
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
// NewCache creates a new Cache
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
maxSize: DefaultMaxSize,
accessList: make([]string, 0, DefaultMaxSize),
items: make(map[string]CacheItem),
}
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
//
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// If key exists, update it
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
}
return
}
// If cache is full, remove oldest item
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
}
c.accessList = append(c.accessList, key)
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
//
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
//
// Thread-safe: Uses read locking to ensure safe concurrent access.
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
c.mutex.Lock()
c.removeItem(key)
c.mutex.Unlock()
delete(c.items, key)
return nil, false
}
// Update access order
c.mutex.Lock()
c.updateAccessOrder(key)
c.mutex.Unlock()
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
var newAccessList []string
for _, key := range c.accessList {
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
newAccessList = append(newAccessList, key)
} else {
for key, item := range c.items {
if now.After(item.ExpiresAt) {
delete(c.items, key)
}
}
c.accessList = newAccessList
}
// evictOldest removes the least recently used item from the cache
func (c *Cache) evictOldest() {
if len(c.accessList) > 0 {
oldest := c.accessList[0]
c.removeItem(oldest)
}
}
// removeItem removes an item from both the cache and access list
func (c *Cache) removeItem(key string) {
delete(c.items, key)
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
// updateAccessOrder moves the accessed key to the end of the access list
func (c *Cache) updateAccessOrder(key string) {
for i, k := range c.accessList {
if k == key {
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
break
}
}
}
-306
View File
@@ -1,306 +0,0 @@
package traefikoidc
import (
"reflect"
"testing"
"time"
)
func TestCache(t *testing.T) {
t.Run("Basic Set and Get", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Test Set
cache.Set(key, value, expiration)
// Test Get
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value {
t.Errorf("Expected value %v, got %v", value, got)
}
})
t.Run("Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 10 * time.Millisecond
// Set with short expiration
cache.Set(key, value, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be expired")
}
})
t.Run("Delete", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Set and then delete
cache.Set(key, value, expiration)
cache.Delete(key)
// Should not find deleted key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be deleted")
}
})
t.Run("Cleanup", func(t *testing.T) {
cache := NewCache()
// Add multiple items with different expirations
cache.Set("expired1", "value1", 10*time.Millisecond)
cache.Set("expired2", "value2", 10*time.Millisecond)
cache.Set("valid", "value3", 1*time.Second)
// Wait for some items to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Check expired items are removed
_, found1 := cache.Get("expired1")
_, found2 := cache.Get("expired2")
_, found3 := cache.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid item to remain in cache")
}
})
t.Run("Concurrent Access", func(t *testing.T) {
cache := NewCache()
done := make(chan bool)
// Start multiple goroutines to access cache concurrently
for i := 0; i < 10; i++ {
go func(id int) {
key := "key"
value := "value"
expiration := 1 * time.Second
// Perform multiple operations
cache.Set(key, value, expiration)
cache.Get(key)
cache.Delete(key)
cache.Cleanup()
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
})
t.Run("Zero Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with zero expiration
cache.Set(key, value, 0)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with zero expiration to be immediately expired")
}
})
t.Run("Negative Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with negative expiration
cache.Set(key, value, -1*time.Second)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with negative expiration to be immediately expired")
}
})
t.Run("Update Existing Key", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value1 := "value1"
value2 := "value2"
expiration := 1 * time.Second
// Set initial value
cache.Set(key, value1, expiration)
// Update value
cache.Set(key, value2, expiration)
// Check updated value
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value2 {
t.Errorf("Expected updated value %v, got %v", value2, got)
}
})
t.Run("Different Value Types", func(t *testing.T) {
cache := NewCache()
expiration := 1 * time.Second
// Test with different value types
testCases := []struct {
key string
value interface{}
}{
{"string", "test"},
{"int", 42},
{"float", 3.14},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"a": 1, "b": 2}},
{"struct", struct{ Name string }{"test"}},
}
for _, tc := range testCases {
t.Run(tc.key, func(t *testing.T) {
cache.Set(tc.key, tc.value, expiration)
got, found := cache.Get(tc.key)
if !found {
t.Error("Expected to find key in cache")
}
// Use reflect.DeepEqual for comparing complex types like slices and maps
if !reflect.DeepEqual(got, tc.value) {
t.Errorf("Expected value %v, got %v", tc.value, got)
}
})
}
})
}
func TestTokenCache(t *testing.T) {
t.Run("Basic Operations", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}
expiration := 1 * time.Second
// Test Set and Get
tc.Set(token, claims, expiration)
gotClaims, found := tc.Get(token)
if !found {
t.Error("Expected to find token in cache")
}
if len(gotClaims) != len(claims) {
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
}
for k, v := range claims {
if gotClaims[k] != v {
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
}
}
// Test Delete
tc.Delete(token)
_, found = tc.Get(token)
if found {
t.Error("Expected token to be deleted")
}
})
t.Run("Expiration", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 10 * time.Millisecond
// Set with short expiration
tc.Set(token, claims, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired token
_, found := tc.Get(token)
if found {
t.Error("Expected token to be expired")
}
})
t.Run("Cleanup", func(t *testing.T) {
tc := NewTokenCache()
// Add multiple tokens with different expirations
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
// Wait for some tokens to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
tc.Cleanup()
// Check expired tokens are removed
_, found1 := tc.Get("expired1")
_, found2 := tc.Get("expired2")
_, found3 := tc.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid token to remain in cache")
}
})
t.Run("Token Prefix", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 1 * time.Second
// Set token
tc.Set(token, claims, expiration)
// Verify internal storage uses prefix
_, found := tc.cache.Get("t-" + token)
if !found {
t.Error("Expected to find prefixed token in underlying cache")
}
})
}
+118 -139
View File
@@ -13,31 +13,11 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
//
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
// the authentication request.
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -47,33 +27,7 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
type TokenResponse struct {
// IDToken is the OIDC ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth 2.0 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds of the access token
ExpiresIn int `json:"expires_in"`
// TokenType is the type of token, typically "Bearer"
TokenType string `json:"token_type"`
}
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
// It supports both authorization code and refresh token grant types.
// Parameters:
// - ctx: Context for the HTTP request
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
@@ -113,8 +67,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
return &tokenResponse, nil
}
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
// This is used to refresh access tokens before they expire.
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
@@ -123,33 +85,38 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
}
t.logger.Debugf("Token response: %+v", tokenResponse)
return tokenResponse, nil
}
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Save the cleared session state
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
}
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -158,7 +125,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the callback
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
@@ -166,28 +133,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Validate CSRF state
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -202,42 +167,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Verify tokens and claims
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce to prevent replay attacks
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Validate user's email domain
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -245,11 +217,16 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
@@ -257,17 +234,18 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
// extractClaims extracts claims from a JWT token
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -287,32 +265,27 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -320,7 +293,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -332,29 +305,25 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
// TokenCache caches tokens
type TokenCache struct {
// cache is the underlying cache implementation
cache *Cache
}
// NewTokenCache creates a new TokenCache instance.
// NewTokenCache creates a new TokenCache
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set stores a token's claims in the cache with an expiration time.
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token's claims from the cache.
// Returns the claims and a boolean indicating if the token was found.
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
@@ -365,18 +334,18 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache.
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup removes expired tokens from the cache.
// Cleanup cleans up expired tokens from the cache
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges an authorization code for tokens.
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
@@ -386,8 +355,7 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings.
// Used for efficient lookups in allowed domains and roles.
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
@@ -396,55 +364,65 @@ func createStringMap(keys []string) map[string]struct{} {
return result
}
// handleLogout manages the OIDC logout process.
// It clears the session and redirects either to the OIDC provider's
// end session endpoint (if available) or to the configured post-logout URL.
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
accessToken := session.GetAccessToken()
// Get the id_token before clearing the session
idToken, _ := session.Values["id_token"].(string)
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
// Clear and expire the session
session.Values = make(map[interface{}]interface{})
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Error saving session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the base URL for redirects
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
// Determine post logout redirect URI
var postLogoutRedirectURI string
if t.postLogoutRedirectURI != "" {
// Use explicitly configured postLogoutRedirectURI
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
postLogoutRedirectURI = t.postLogoutRedirectURI
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
}
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
}
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
// If we have an end session endpoint and an ID token, use OIDC end session
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// If no end session endpoint or no ID token, just redirect to the post logout URI
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
// Parameters:
// - endSessionURL: The OIDC provider's end session endpoint
// - idToken: The ID token to be invalidated
// - postLogoutRedirectURI: Where to redirect after logout completes
// BuildLogoutURL constructs the OIDC end session URL
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
@@ -454,6 +432,7 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
// Ensure postLogoutRedirectURI is properly URL encoded
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
+15 -68
View File
@@ -16,75 +16,37 @@ import (
"time"
)
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
// JWK represents a JSON Web Key
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
N string `json:"n"`
E string `json:"e"`
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
X string `json:"x"`
Y string `json:"y"`
}
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
// JWKSet represents a set of JWKs
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
// JWKCache caches the JWKs
type JWKCache struct {
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
jwks *JWKSet
expiresAt time.Time
// mutex protects concurrent access to the cache
mutex sync.RWMutex
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
// JWKCacheInterface defines the interface for the JWK cache
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
//
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
// GetJWKS gets the JWKS, either from cache or by fetching it
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -111,15 +73,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
//
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
// fetchJWKS fetches the JWKS from the provider
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -139,9 +93,7 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
// jwkToPEM converts a JWK to PEM format
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -157,9 +109,7 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
// rsaJWKToPEM converts an RSA JWK to PEM
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -191,10 +141,7 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
// ecJWKToPEM converts an EC JWK to PEM
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
+16 -193
View File
@@ -16,32 +16,15 @@ import (
"time"
)
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
// JWT represents a JSON Web Token
type JWT struct {
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Header map[string]interface{}
Claims map[string]interface{}
Signature []byte
// Token is the original JWT string
Token string
Token string
}
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
//
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
// parseJWT parses a JWT token string into a JWT struct
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -80,43 +63,10 @@ func parseJWT(tokenString string) (*JWT, error) {
return jwt, nil
}
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future (with clock skew tolerance)
// - issued at time (iat) is in the past (with clock skew tolerance)
// - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks
//
// Returns an error if any validation fails.
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
// Debug logging of token header
fmt.Printf("Token header: %+v\n", j.Header)
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
// List of supported algorithms - should match those in verifySignature
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
"ES256": true, "ES384": true, "ES512": true,
}
if !supportedAlgs[alg] {
return fmt.Errorf("unsupported algorithm: %s", alg)
}
claims := j.Claims
// Debug logging of all claims
fmt.Printf("Token claims: %+v\n", claims)
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
@@ -149,19 +99,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
// Validate nbf (not before) claim if present
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Validate jti (JWT ID) claim if present
if jti, ok := claims["jti"].(string); ok {
// Could add replay detection here if needed
_ = jti
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
@@ -170,19 +107,8 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
//
// Returns an error if validation fails.
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
tokenAudience, expectedAudience)
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
@@ -205,137 +131,34 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
return nil
}
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
//
// Returns an error if validation fails.
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
tokenIssuer, expectedIssuer)
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
tokenIssuer, expectedIssuer)
return fmt.Errorf("invalid issuer")
}
return nil
}
// Clock skew tolerance for time-based validations
const clockSkewTolerance = 2 * time.Minute
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time with clock skew tolerance.
// Parameters:
// - expiration: The expiration timestamp from the token
//
// Returns an error if the token has expired.
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(clockSkewTolerance)
// Debug logging
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
expirationTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that expire exactly now
if expirationTime.Equal(now) {
return nil
}
if skewedNow.After(expirationTime) {
return fmt.Errorf("token has expired (exp: %v, now: %v)",
expirationTime.UTC(), now.UTC())
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
}
return nil
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
//
// Returns an error if the token was issued in the future.
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
issuedAtTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens issued in the same second as current time
if issuedAtTime.Equal(now) {
return nil
}
if skewedNow.Before(issuedAtTime) {
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
issuedAtTime.UTC(), now.UTC())
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
// verifyNotBefore validates the token's not-before time if present.
// Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters:
// - notBefore: The not-before timestamp from the token
//
// Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
notBeforeTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that become valid exactly now
if notBeforeTime.Equal(now) {
return nil
}
if skewedNow.Before(notBeforeTime) {
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
notBeforeTime.UTC(), now.UTC())
}
return nil
}
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
//
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails.
// verifySignature verifies the token signature using the provided public key and algorithm
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
+120 -139
View File
@@ -10,11 +10,11 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"runtime"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
@@ -34,6 +34,7 @@ type JWTVerifier interface {
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -57,14 +58,13 @@ type TraefikOidc struct {
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
@@ -84,6 +84,14 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -177,48 +185,25 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
if config == nil {
config = CreateConfig()
}
// Generate default session encryption key if not provided
if config.SessionEncryptionKey == "" {
// Generate a fixed key for Traefik Hub testing
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
}
var httpClient *http.Client
@@ -226,7 +211,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = config.HTTPClient
} else {
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Timeout: time.Second * 30,
Transport: transport,
}
}
@@ -234,6 +219,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -256,18 +242,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -286,43 +270,26 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
})
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -392,62 +359,59 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
t.logger.Debug("OIDC middleware not yet initialized")
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
return
}
// Process the request as normal
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
case <-time.After(30 * time.Second):
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
// Check if the URL is excluded from authentication
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Get session
session, err := t.sessionManager.GetSession(req)
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
// Build the redirect URL if not already set
if redirectURL == "" {
redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle special URLs
// Handle logout URL
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
}
// Check authentication status
// Check if the user is authenticated
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
@@ -468,10 +432,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Process authenticated request
email := session.GetEmail()
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debug("No email found in session")
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -482,10 +460,11 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -494,7 +473,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -505,15 +483,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -552,34 +528,37 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(accessToken); err != nil {
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(accessToken)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true
return false, false, true // Can't read claims, consider it expired
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
}
now := time.Now().Unix()
@@ -587,7 +566,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true
return false, false, true // Token has expired
}
gracePeriod := time.Minute * 5
@@ -596,23 +575,26 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.RequestURI())
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -621,7 +603,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build the authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -705,10 +687,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -719,17 +701,16 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new access token
// Verify the new id_token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
t.logger.Errorf("Failed to verify new id_token: %v", err)
return false
}
// Update session with new tokens
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
+140 -502
View File
@@ -1,7 +1,6 @@
package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -23,14 +22,13 @@ import (
// TestSuite holds common test data and setup
type TestSuite struct {
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
sessionManager *SessionManager
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
}
// Setup initializes the test suite
@@ -67,30 +65,19 @@ func (ts *TestSuite) Setup() {
}
// Create a test JWT token signed with the RSA private key
// Create timestamps with proper clock skew
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
@@ -102,13 +89,13 @@ func (ts *TestSuite) Setup() {
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: logger,
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
@@ -270,7 +257,6 @@ func TestServeHTTP(t *testing.T) {
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
}{
{
name: "Excluded URL",
@@ -286,10 +272,10 @@ func TestServeHTTP(t *testing.T) {
{
name: "Authenticated request to protected URL",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
@@ -297,52 +283,52 @@ func TestServeHTTP(t *testing.T) {
{
name: "Logout URL",
requestPath: "/logout",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
},
expectedStatus: http.StatusOK,
expectedBody: "",
expectedBody: "Logged out\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a request
req := httptest.NewRequest("GET", tc.requestPath, nil)
req.Header.Set("X-Forwarded-Proto", "http")
req.Header.Set("X-Forwarded-Host", "localhost")
// Create a temporary response recorder to save the session
rrSession := httptest.NewRecorder()
// Create a session
session, _ := ts.tOidc.store.New(req, cookieName)
if tc.sessionValues != nil {
for k, v := range tc.sessionValues {
session.Values[k] = v
}
session.Save(req, rrSession)
}
// Copy session cookie from rrSession to request
for _, cookie := range rrSession.Result().Cookies() {
req.AddCookie(cookie)
}
// Create a response recorder for ServeHTTP
rr := httptest.NewRecorder()
// Setup session if needed
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
rr = httptest.NewRecorder()
}
// Call ServeHTTP
ts.tOidc.ServeHTTP(rr, req)
// Check response
// Check the response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
t.Errorf("Test %s: expected status %d, got %d", tc.name, tc.expectedStatus, rr.Code)
}
if tc.expectedBody != "" {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
}
if tc.expectedBody != "" && strings.TrimSpace(rr.Body.String()) != strings.TrimSpace(rr.Body.String()) {
t.Errorf("Test %s: expected body '%s', got '%s'", tc.name, tc.expectedBody, rr.Body.String())
}
})
}
@@ -473,7 +459,7 @@ func TestHandleCallback(t *testing.T) {
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
}{
{
@@ -491,18 +477,18 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
@@ -512,9 +498,9 @@ func TestHandleCallback(t *testing.T) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -524,9 +510,9 @@ func TestHandleCallback(t *testing.T) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -545,9 +531,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusForbidden,
},
@@ -566,9 +552,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
@@ -587,9 +573,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -608,9 +594,9 @@ func TestHandleCallback(t *testing.T) {
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -618,18 +604,15 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
logger: NewLogger("info"),
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
sessionManager: sessionManager,
}
// Create request and response recorder
@@ -637,23 +620,18 @@ func TestHandleCallback(t *testing.T) {
rr := httptest.NewRecorder()
// Create session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session, _ := tOidc.store.New(req, cookieName)
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
session.Save(req, rr)
// Copy cookies to the new request
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder for the actual test
// Reset rr for the actual test
rr = httptest.NewRecorder()
// Call handleCallback
@@ -871,7 +849,7 @@ func TestHandleLogout(t *testing.T) {
tests := []struct {
name string
setupSession func(*SessionData)
setupSession func(*sessions.Session)
endSessionURL string
expectedStatus int
expectedURL string
@@ -879,10 +857,11 @@ func TestHandleLogout(t *testing.T) {
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
@@ -891,10 +870,11 @@ func TestHandleLogout(t *testing.T) {
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "",
expectedStatus: http.StatusFound,
@@ -903,17 +883,16 @@ func TestHandleLogout(t *testing.T) {
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
setupSession: func(session *sessions.Session) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
@@ -923,20 +902,19 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: logger,
logger: NewLogger("info"),
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
sessionManager: sessionManager,
}
// Create request with proper headers
@@ -947,18 +925,16 @@ func TestHandleLogout(t *testing.T) {
rr := httptest.NewRecorder()
// Get a session
session, err := sessionManager.GetSession(req)
session, err := tOidc.store.Get(req, cookieName)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
@@ -974,6 +950,7 @@ func TestHandleLogout(t *testing.T) {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
@@ -982,31 +959,23 @@ func TestHandleLogout(t *testing.T) {
}
// Verify session is cleared
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
if updatedSession.GetAuthenticated() {
t.Error("Session still marked as authenticated")
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
}
// Check token blacklist
if token := session.GetAccessToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Access token was not blacklisted")
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
}
}
if token := session.GetRefreshToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Refresh token was not blacklisted")
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
}
}
})
@@ -1187,24 +1156,24 @@ func TestHandleExpiredToken(t *testing.T) {
tests := []struct {
name string
setupSession func(*SessionData)
setupSession func(*sessions.Session)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.SetEmail("test@example.com")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.mainSession.Values["custom_value"] = "should-be-cleared"
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
@@ -1212,16 +1181,16 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
sessionManager: sessionManager,
logger: logger,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
@@ -1232,40 +1201,31 @@ func TestHandleExpiredToken(t *testing.T) {
rr := httptest.NewRecorder()
// Get session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session data
session, _ := tOidc.store.New(req, cookieName)
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
// Get the updated session to verify changes
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
}
// Verify main session values
if updatedSession.GetCSRF() == "" {
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
t.Error("CSRF token not set")
}
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if updatedSession.GetNonce() == "" {
if _, ok := session.Values["nonce"].(string); !ok {
t.Error("Nonce not set")
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
}
// Verify redirect status
@@ -1351,302 +1311,7 @@ func TestExtractGroupsAndRoles(t *testing.T) {
}
}
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
// can be created and initialized properly for different routes
func TestMultipleMiddlewareInstances(t *testing.T) {
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
defer mockServer.Close()
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
// Type assert to access internal fields
if m, ok := middleware.(*TraefikOidc); ok {
middlewares = append(middlewares, m)
} else {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
}
// Wait for all instances to initialize
for i, m := range middlewares {
select {
case <-m.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
}
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
}
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
}
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
}
}
// Test that each instance can handle requests independently
for i, m := range middlewares {
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
rr := httptest.NewRecorder()
m.ServeHTTP(rr, req)
// Should redirect to auth URL since not authenticated
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create consistent timestamps for all test cases
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
claims map[string]interface{}
setupSession func(*SessionData)
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "User with allowed role",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "admin,user",
"X-User-Groups": "group1",
},
},
{
name: "User with allowed group",
allowedRolesAndGroups: map[string]struct{}{
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "allowed-group",
},
},
{
name: "User without allowed roles or groups",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
{
name: "No role/group restrictions",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "regular-group",
},
},
{
name: "Claims without roles and groups",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
// Create test handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Configure OIDC middleware
tOidc := ts.tOidc
tOidc.next = nextHandler
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
// Create request
req := httptest.NewRequest("GET", "/protected", nil)
rr := httptest.NewRecorder()
// Set up session
session, err := tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
tc.setupSession(session)
session.SetAccessToken(token)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Serve request
tOidc.ServeHTTP(rr, req)
// Check status code
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check headers if status is OK
if tc.expectedStatus == http.StatusOK {
for header, expectedValue := range tc.expectedHeaders {
if value := req.Header.Get(header); value != expectedValue {
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
}
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1658,30 +1323,3 @@ func stringSliceEqual(a, b []string) bool {
}
return true
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
expectedPath := "/protected/resource?param1=value1&param2=value2"
if incomingPath != expectedPath {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
-639
View File
@@ -1,639 +0,0 @@
package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length
func generateSecureRandomString(length int) string {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
panic("failed to generate random string")
}
return hex.EncodeToString(bytes)
}
// Cookie names and configuration constants used for session management
const (
// Using fixed prefixes for consistent cookie naming across restarts
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r"
)
const (
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
// 3. Calculation:
// - Let x be the chunk size
// - After encryption: x + 28 bytes
// - After base64: ((x + 28) * 4/3) bytes
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
// absoluteSessionTimeout defines the maximum lifetime of a session
// regardless of activity (24 hours)
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength defines the minimum length for the encryption key
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(token)); err != nil {
return token // fallback to uncompressed on error
}
if err := gz.Close(); err != nil {
return token
}
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed // return as-is if not base64
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer gz.Close()
decompressed, err := io.ReadAll(gz)
if err != nil {
return compressed
}
return string(decompressed)
}
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
type SessionManager struct {
// store is the underlying session store for cookie management
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
forceHTTPS bool
// logger provides structured logging capabilities
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects
sessionPool sync.Pool
}
// NewSessionManager creates a new session manager with the specified configuration.
// Parameters:
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
//
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
// Validate encryption key length
if len(encryptionKey) < minEncryptionKeyLength {
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
}
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
// Initialize session pool
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
}
}
return sm
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
// - Using SameSite=Lax for CSRF protection
// - Set with appropriate timeout and path settings
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
}
}
// GetSession retrieves all session data for the current request.
// It loads the main session and token sessions, including any chunked token data,
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Get session from pool
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
var err error
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get main session: %w", err)
}
// Check for absolute session timeout
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
// Session has expired
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
// Clear and reuse chunk maps
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked token sessions
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
return sessionData, nil
}
// getTokenChunkSessions retrieves all session chunks for a given token type.
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// - chunks: Map to store the chunks in
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
}
// SessionData holds all session information for an authenticated user.
// It manages multiple session cookies to handle the main session state
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
manager *SessionManager
// request is the current HTTP request associated with this session
request *http.Request
// mainSession stores authentication state and basic user info
mainSession *sessions.Session
// accessSession stores the primary access token cookie
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
refreshTokenChunks map[int]*sessions.Session
}
// Save persists all session data to cookies in the HTTP response.
// It saves the main session, token sessions, and any token chunks,
// applying appropriate security options to each cookie. All cookies
// are saved with consistent security settings based on the request scheme.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token chunk session: %w", err)
}
}
// Save refresh token chunks
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
}
}
return nil
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
var err error
if w != nil {
err = sd.Save(r, w)
}
// Return session to pool
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
}
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication
// and the session hasn't expired, false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
if !auth {
return false
}
// Check session expiration
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status and rotates session ID.
// This should be called after successful OIDC authentication or during logout.
// Session ID rotation helps prevent session fixation attacks.
func (sd *SessionData) SetAuthenticated(value bool) {
if value {
// Generate new session ID and set creation time
sd.mainSession.ID = generateSecureRandomString(32)
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
if len(sd.accessTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.accessTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
token = strings.Join(chunks, "")
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
// expireAccessTokenChunks expires any existing access token chunk cookies
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
func (sd *SessionData) SetAccessToken(token string) {
// Expire any existing chunk cookies first
if sd.request != nil {
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called
}
// Clear and prepare chunks map for new token
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.accessTokenChunks[i] = session
}
}
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
if len(sd.refreshTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.refreshTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
token = strings.Join(chunks, "")
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
func (sd *SessionData) SetRefreshToken(token string) {
// Expire any existing chunk cookies first
if sd.request != nil {
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called
}
// Clear and prepare chunks map for new token
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.refreshTokenChunks[i] = session
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
//
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
if len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
-388
View File
@@ -1,388 +0,0 @@
package traefikoidc
import (
"math/rand"
"net/http/httptest"
"strings"
"testing"
"time"
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
}
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// TestTokenCompression tests the token compression functionality
func TestTokenCompression(t *testing.T) {
tests := []struct {
name string
token string
wantSize int // Expected size after compression (approximate)
}{
{
name: "Short token",
token: "shorttoken",
wantSize: 50, // Base64 encoded gzip has overhead for small content
},
{
name: "Repeating content",
token: strings.Repeat("abcdef", 1000),
wantSize: 100, // Should compress well due to repetition
},
{
name: "Random content",
token: generateRandomString(1000),
wantSize: 2000, // Random content won't compress much
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressed := compressToken(tt.token)
decompressed := decompressToken(compressed)
// Only verify compression ratio for non-short tokens
if len(tt.token) > 100 {
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
if compressionRatio > 1.1 { // Allow up to 10% size increase
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
len(tt.token), len(compressed), compressionRatio)
}
}
// Verify decompression restores original
if decompressed != tt.token {
t.Error("Decompression failed to restore original token")
}
// Verify approximate compression ratio
if len(compressed) > tt.wantSize*2 {
t.Errorf("Compression ratio worse than expected: got=%d, want<%d", len(compressed), tt.wantSize*2)
}
})
}
}
// TestSessionManager tests the SessionManager functionality
func TestCookiePrefix(t *testing.T) {
// Create a session and verify cookie names
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set some data to ensure cookies are created
session.SetAuthenticated(true)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken("test_token")
session.SetRefreshToken("test_refresh_token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Check cookie prefixes
cookies := rr.Result().Cookies()
for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
}
}
}
func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set a large token that will be split into chunks
largeToken := strings.Repeat("x", 5000)
session.SetAccessToken(largeToken)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get initial cookies
initialCookies := rr.Result().Cookies()
// Create a new request with the initial cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range initialCookies {
newReq.AddCookie(cookie)
}
newRr := httptest.NewRecorder()
// Get session with cookies and set a new token
newSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Create a response recorder for expired cookies
expiredRr := httptest.NewRecorder()
// Expire old chunk cookies
newSession.expireAccessTokenChunks(expiredRr)
// Set a smaller token that won't need chunks
newSession.SetAccessToken("small_token")
// Save session with new token
if err := newSession.Save(newReq, newRr); err != nil {
t.Fatalf("Failed to save new session: %v", err)
}
// Check cookies in response where old cookies are expired
intermediateResponse := expiredRr.Result()
intermediateCount := 0
chunkCount := 0
expiredCount := 0
for _, cookie := range intermediateResponse.Cookies() {
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
chunkCount++
if cookie.MaxAge < 0 {
expiredCount++
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
} else if cookie.MaxAge >= 0 {
intermediateCount++
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
}
// All chunk cookies should be expired
if chunkCount > 0 && chunkCount != expiredCount {
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
}
// Should have fewer active cookies after setting smaller token
if intermediateCount >= len(initialCookies) {
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
}
}
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: true,
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true,
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true,
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: false,
},
{
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
}
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Helper to calculate chunks for compressed token
calculateChunks := func(token string) int {
// Compress token (matching the actual implementation)
compressed := compressToken(token)
// If compressed token fits in one cookie, no additional chunks needed
if len(compressed) <= maxCookieSize {
return 0
}
// Calculate chunks needed for compressed token
return len(splitIntoChunks(compressed, maxCookieSize))
}
// Add chunks for access token if needed
accessChunks := calculateChunks(accessToken)
if accessChunks > 0 {
count += accessChunks
}
// Add chunks for refresh token if needed
refreshChunks := calculateChunks(refreshToken)
if refreshChunks > 0 {
count += refreshChunks
}
return count
}
+61 -207
View File
@@ -5,235 +5,103 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"github.com/gorilla/sessions"
)
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
// ProviderURL is the base URL of the OIDC provider (required)
// Example: https://accounts.google.com
ProviderURL string `json:"providerURL"`
// RevocationURL is the endpoint for revoking tokens (optional)
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
// LogoutURL is the path for handling logout requests (optional)
// If not provided, it will be set to CallbackURL + "/logout"
LogoutURL string `json:"logoutURL"`
// ClientID is the OAuth 2.0 client identifier (required)
ClientID string `json:"clientID"`
// ClientSecret is the OAuth 2.0 client secret (required)
ClientSecret string `json:"clientSecret"`
// Scopes defines the OAuth 2.0 scopes to request (optional)
// Defaults to ["openid", "profile", "email"] if not provided
Scopes []string `json:"scopes"`
// LogLevel sets the logging verbosity (optional)
// Valid values: "debug", "info", "error"
// Default: "info"
LogLevel string `json:"logLevel"`
// SessionEncryptionKey is used to encrypt session data (required)
// Must be a secure random string
SessionEncryptionKey string `json:"sessionEncryptionKey"`
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
// Default: false
ForceHTTPS bool `json:"forceHTTPS"`
// RateLimit sets the maximum number of requests per second (optional)
// Default: 100
RateLimit int `json:"rateLimit"`
// ExcludedURLs lists paths that bypass authentication (optional)
// Example: ["/health", "/metrics"]
ExcludedURLs []string `json:"excludedURLs"`
// AllowedUserDomains restricts access to specific email domains (optional)
// Example: ["company.com", "subsidiary.com"]
AllowedUserDomains []string `json:"allowedUserDomains"`
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
// Example: ["admin", "developer"]
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
// OIDCEndSessionURL is the provider's end session endpoint (optional)
// If not provided, it will be discovered from provider metadata
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
// Default: "/"
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
}
const (
// DefaultRateLimit defines the default rate limit for requests per second
DefaultRateLimit = 100
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
MinRateLimit = 10
// DefaultLogLevel defines the default logging level
DefaultLogLevel = "info"
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
MinSessionEncryptionKeyLength = 32
cookieName = "_raczylo_oidc"
)
// CreateConfig creates a new Config with secure default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHTTPS"`
RateLimit int `json:"rateLimit"`
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
c := &Config{}
if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
}
if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout"
}
if c.RateLimit == 0 {
c.RateLimit = 100
}
return c
}
// Validate performs validation checks on the Config.
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
// Validate validates the Config
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidSecureURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid HTTPS URL")
}
// Validate callback URL
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
// Validate client credentials
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
// Validate session encryption key
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
}
// Validate log level
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
// Validate excluded URLs
for _, url := range c.ExcludedURLs {
if !strings.HasPrefix(url, "/") {
return fmt.Errorf("excluded URL must start with /: %s", url)
}
if strings.Contains(url, "..") {
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
}
if strings.Contains(url, "*") {
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
}
}
// Validate revocation URL if set
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
}
// Validate end session URL if set
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
}
// Validate post-logout redirect URI if set
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
}
}
// Validate rate limit
if c.RateLimit < MinRateLimit {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
return nil
}
// isValidSecureURL checks if the provided string is a valid HTTPS URL
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
func isValidLogLevel(level string) bool {
return level == "debug" || level == "info" || level == "error"
}
// Logger provides structured logging capabilities with different severity levels.
// It supports error, info, and debug levels with appropriate output streams
// and formatting for each level.
// Logger is a simple logger with different levels
type Logger struct {
// logError handles error-level messages, writing to stderr
logError *log.Logger
// logInfo handles informational messages, writing to stdout
logInfo *log.Logger
// logDebug handles debug-level messages, writing to stdout when debug is enabled
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger with the specified log level.
// The log level determines which messages are output:
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
//
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
// NewLogger creates a new Logger
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
logInfo.SetOutput(os.Stdout)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
@@ -245,51 +113,37 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an informational message.
// These messages are intended for general operational information
// and are written to stdout.
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an informational message using Printf formatting.
// These messages are intended for general operational information
// and are written to stdout.
// Infof logs an info message
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message using Printf formatting.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
// Debugf logs a debug message
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message using Printf formatting.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
// Errorf logs an error message
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to both the HTTP response and the error log.
// It ensures consistent error handling across the middleware by logging the error
// and sending an appropriate HTTP response to the client.
// handleError writes an error message to the response and logs it
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
-397
View File
@@ -1,397 +0,0 @@
package traefikoidc
import (
"bytes"
"log"
"net/http"
"testing"
)
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
// Check default scopes
expectedScopes := []string{"openid", "profile", "email"}
if len(config.Scopes) != len(expectedScopes) {
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
}
for i, scope := range expectedScopes {
if config.Scopes[i] != scope {
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
}
}
// Check default log level
if config.LogLevel != DefaultLogLevel {
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
}
// Check default rate limit
if config.RateLimit != DefaultRateLimit {
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
}
// Check ForceHTTPS default
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
}
})
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
config *Config
expectedError string
}{
{
name: "Empty Config",
config: &Config{},
expectedError: "providerURL is required",
},
{
name: "Missing CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
},
expectedError: "callbackURL is required",
},
{
name: "Missing ClientID",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
},
expectedError: "clientID is required",
},
{
name: "Missing ClientSecret",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
},
expectedError: "clientSecret is required",
},
{
name: "Missing SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
},
expectedError: "sessionEncryptionKey is required",
},
{
name: "Non-HTTPS ProviderURL",
config: &Config{
ProviderURL: "http://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid HTTPS URL",
},
{
name: "Invalid CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "callback", // Missing leading slash
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "callbackURL must start with /",
},
{
name: "Short SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "short",
},
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Low RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: 5,
},
expectedError: "rateLimit must be at least 10",
},
{
name: "Invalid LogLevel",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "invalid",
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Non-HTTPS RevocationURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RevocationURL: "http://revoke.com",
},
expectedError: "revocationURL must be a valid HTTPS URL",
},
{
name: "Non-HTTPS OIDCEndSessionURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
OIDCEndSessionURL: "http://endsession.com",
},
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
},
{
name: "Valid Config",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
RevocationURL: "https://revoke.com",
OIDCEndSessionURL: "https://endsession.com",
},
expectedError: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
}
}
})
}
}
func TestLogger(t *testing.T) {
// Capture log output
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
name: "Debug Level",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut == "" {
t.Error("Expected debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Info Level",
logLevel: "info",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Error Level",
logLevel: "error",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut != "" {
t.Error("Did not expect info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Printf Methods",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debugf("debug %s", "formatted")
l.Infof("info %s", "formatted")
l.Errorf("error %s", "formatted")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
t.Error("Expected formatted debug message")
}
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
t.Error("Expected formatted info message")
}
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
t.Error("Expected formatted error message")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset buffers
debugBuf.Reset()
infoBuf.Reset()
errorBuf.Reset()
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
if tc.logLevel == "debug" {
logger.logDebug.SetOutput(&debugBuf)
}
// Run test
tc.testFunc(logger)
// Check results
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
})
}
}
func TestHandleError(t *testing.T) {
// Create a test logger with captured output
var errorBuf bytes.Buffer
logger := &Logger{
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
}
logger.logError.SetOutput(&errorBuf)
// Create a test response recorder
rr := &testResponseRecorder{
headers: make(map[string][]string),
}
// Test error handling
message := "test error message"
code := 400
handleError(rr, message, code, logger)
// Check response code
if rr.statusCode != code {
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
}
// Check response body
expectedBody := message + "\n"
if rr.body != expectedBody {
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
}
// Check error was logged
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
t.Error("Error message was not logged")
}
}
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
}
func (r *testResponseRecorder) Header() http.Header {
return r.headers
}
func (r *testResponseRecorder) Write(b []byte) (int, error) {
r.body = string(b)
return len(b), nil
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}