Compare commits

...

9 Commits

Author SHA1 Message Date
lukaszraczylo 9402f1bca5 Token blacklist, cache and metadata improvements
TokenBlacklist Improvements:
Fixed size limit enforcement to properly maintain max size of 1000 tokens
Improved eviction strategy to remove expired tokens first before removing oldest
Added proper cleanup of tokens during Add operation to prevent size overflow
Fixed oldest token eviction logic to ensure correct token removal
Added proper locking mechanisms to prevent race conditions
Cache Improvements:
Fixed cleanup mechanism to only remove truly expired items
Improved eviction strategy in LRU cache to prioritize expired items
Added smarter eviction in evictOldest to scan for expired items first
Fixed aggressive cleanup that was removing valid items
Maintained proper LRU ordering while handling evictions
MetadataCache:
Verified proper implementation of metadata caching with hourly refresh
Confirmed proper handling of cache extension on fetch failures
Validated thread-safe operations with proper RWMutex usage
2025-02-09 23:53:05 +00:00
lukaszraczylo e6205b3a48 Add metadata caching capability to avoid unnecesary API calls 2025-02-09 23:37:50 +00:00
lukaszraczylo fdb8e3233e Testing (could be unstable) additional headers.
This adds additional headers to control the access origin and control allow headers.
2025-02-06 23:46:08 +00:00
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
8 changed files with 742 additions and 75 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
## TODO / wishlist
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
+110
View File
@@ -0,0 +1,110 @@
package traefikoidc
import (
"sync"
"time"
)
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
type TokenBlacklist struct {
tokens map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new token blacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
tokens: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
b.mutex.Lock()
defer b.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(b.tokens) >= 1000 {
now := time.Now()
futureThreshold := now.Add(time.Minute)
for t, exp := range b.tokens {
if now.After(exp) || futureThreshold.After(exp) {
delete(b.tokens, t)
}
}
// If still at capacity, remove oldest token
if len(b.tokens) >= 1000 {
var oldestToken string
var oldestTime time.Time
first := true
for t, exp := range b.tokens {
if first || exp.Before(oldestTime) {
oldestToken = t
oldestTime = exp
first = false
}
}
if oldestToken != "" {
delete(b.tokens, oldestToken)
}
}
}
b.tokens[token] = expiry
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
b.mutex.RLock()
defer b.mutex.RUnlock()
expiry, exists := b.tokens[token]
if !exists {
return false
}
// If token is expired, remove it and return false
if time.Now().After(expiry) {
// Switch to write lock to remove expired token
b.mutex.RUnlock()
b.mutex.Lock()
delete(b.tokens, token)
b.mutex.Unlock()
b.mutex.RLock()
return false
}
return true
}
// Cleanup removes expired tokens from the blacklist.
// Also removes tokens that will expire within the next minute to prevent edge cases.
func (b *TokenBlacklist) Cleanup() {
b.mutex.Lock()
defer b.mutex.Unlock()
now := time.Now()
futureThreshold := now.Add(time.Minute)
for token, expiry := range b.tokens {
// Remove tokens that are expired or will expire soon
if now.After(expiry) || futureThreshold.After(expiry) {
delete(b.tokens, token)
}
}
}
// Remove removes a token from the blacklist regardless of its expiration.
func (b *TokenBlacklist) Remove(token string) {
b.mutex.Lock()
defer b.mutex.Unlock()
delete(b.tokens, token)
}
// Count returns the current number of tokens in the blacklist.
func (b *TokenBlacklist) Count() int {
b.mutex.RLock()
defer b.mutex.RUnlock()
return len(b.tokens)
}
+17 -1
View File
@@ -128,6 +128,7 @@ func (c *Cache) Cleanup() {
now := time.Now()
for key, item := range c.items {
// Only remove items that are already expired
if now.After(item.ExpiresAt) {
c.removeItem(key)
}
@@ -136,8 +137,23 @@ func (c *Cache) Cleanup() {
// evictOldest removes the least recently used item from the cache.
func (c *Cache) evictOldest() {
now := time.Now()
elem := c.order.Front()
if elem != nil {
// First try to find an expired item from the front
for elem != nil {
entry := elem.Value.(lruEntry)
if item, exists := c.items[entry.key]; exists {
if now.After(item.ExpiresAt) {
c.removeItem(entry.key)
return
}
}
elem = elem.Next()
}
// If no expired items found, remove the oldest item
if elem = c.order.Front(); elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
}
+17 -47
View File
@@ -8,9 +8,9 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
"time"
)
@@ -68,13 +68,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -267,51 +282,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
}
}
}
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
+227
View File
@@ -0,0 +1,227 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if tb.Count() != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if tb.Count() != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
}
// Verify only valid tokens remain
tb.mutex.RLock()
defer tb.mutex.RUnlock()
for token, expiry := range tb.tokens {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+120 -26
View File
@@ -39,6 +39,7 @@ type TraefikOidc struct {
issuerURL string
revocationURL string
jwkCache JWKCacheInterface
metadataCache *MetadataCache
tokenBlacklist *TokenBlacklist
jwksURL string
clientID string
@@ -225,6 +226,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
@@ -246,6 +254,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
metadataCache: NewMetadataCache(),
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
@@ -285,40 +294,55 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
// Get metadata from cache or fetch it
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to get provider metadata: %v", err)
return
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Start metadata refresh goroutine
go t.startMetadataRefresh(providerURL)
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata")
}
// startMetadataRefresh periodically refreshes the OIDC metadata
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Refreshing OIDC metadata")
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
t.logger.Errorf("Failed to refresh metadata: %v", err)
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
t.logger.Debug("Successfully refreshed metadata")
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
@@ -509,7 +533,35 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -639,17 +691,59 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute)
ctx, cancel := context.WithCancel(context.Background())
ticker := time.NewTicker(15 * time.Second) // More frequent cleanup
go func() {
defer ticker.Stop()
defer cancel()
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
t.logger.Debug("Starting token cleanup cycle")
// Run cleanup in a separate goroutine with shorter timeout
cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 5*time.Second)
done := make(chan struct{})
go func() {
defer close(done)
// Clean up in smaller batches to prevent long-running operations
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
// Force garbage collection after cleanup
runtime.GC()
}()
// Wait for cleanup to complete or timeout
select {
case <-cleanupCtx.Done():
if cleanupCtx.Err() == context.DeadlineExceeded {
t.logger.Error("Token cleanup cycle timed out")
}
case <-done:
t.logger.Debug("Token cleanup cycle completed successfully")
}
cleanupCancel()
}
}()
}
+196
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -1647,6 +1648,111 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1659,6 +1765,96 @@ func stringSliceEqual(a, b []string) bool {
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
+54
View File
@@ -0,0 +1,54 @@
package traefikoidc
import (
"fmt"
"net/http"
"sync"
"time"
)
// MetadataCache provides thread-safe caching for OIDC provider metadata
type MetadataCache struct {
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
}
// NewMetadataCache creates a new metadata cache instance
func NewMetadataCache() *MetadataCache {
return &MetadataCache{}
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
if c.metadata != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock()
return c.metadata, nil
}
c.mutex.RUnlock()
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check after acquiring write lock
if c.metadata != nil && time.Now().Before(c.expiresAt) {
return c.metadata, nil
}
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
if err != nil {
if c.metadata != nil {
// On error, extend current cache by 5 minutes to prevent thundering herd
c.expiresAt = time.Now().Add(5 * time.Minute)
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
return c.metadata, nil
}
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
}
c.metadata = metadata
c.expiresAt = time.Now().Add(1 * time.Hour)
return metadata, nil
}