Compare commits

...

6 Commits

Author SHA1 Message Date
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
5 changed files with 522 additions and 7 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
## TODO / wishlist
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
+45 -1
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
@@ -68,13 +69,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -276,12 +292,16 @@ type TokenBlacklist struct {
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
// maxSize is the maximum number of tokens in the blacklist
maxSize int
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
maxSize: 1000, // Limit the size to prevent unbounded growth
}
}
@@ -289,6 +309,30 @@ func NewTokenBlacklist() *TokenBlacklist {
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(tb.blacklist) >= tb.maxSize {
now := time.Now()
for token, exp := range tb.blacklist {
if now.After(exp) {
delete(tb.blacklist, token)
}
}
// If still at capacity after cleanup, remove oldest token
if len(tb.blacklist) >= tb.maxSize {
var oldestToken string
var oldestTime time.Time
first := true
for token, exp := range tb.blacklist {
if first || exp.Before(oldestTime) {
oldestToken = token
oldestTime = exp
first = false
}
}
delete(tb.blacklist, oldestToken)
}
}
tb.blacklist[tokenID] = expiration
}
+225
View File
@@ -0,0 +1,225 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if len(tb.blacklist) != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", len(tb.blacklist))
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if len(tb.blacklist) > 1000 {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if len(tb.blacklist) != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", len(tb.blacklist))
}
// Verify only valid tokens remain
for token, expiry := range tb.blacklist {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", len(tb.blacklist))
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+55 -5
View File
@@ -225,6 +225,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
@@ -639,17 +646,60 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute)
ctx, cancel := context.WithCancel(context.Background())
ticker := time.NewTicker(30 * time.Second) // Increased frequency to prevent memory buildup
go func() {
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
defer ticker.Stop()
defer cancel()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.logger.Debug("Starting token cleanup cycle")
// Run cleanup in a separate goroutine with timeout
cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 10*time.Second)
done := make(chan struct{})
go func() {
defer close(done)
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
}()
// Wait for cleanup to complete or timeout
select {
case <-cleanupCtx.Done():
if cleanupCtx.Err() == context.DeadlineExceeded {
t.logger.Error("Token cleanup cycle timed out")
}
case <-done:
t.logger.Debug("Token cleanup cycle completed successfully")
}
cleanupCancel()
}
}
}()
}
+196
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -1647,6 +1648,111 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1659,6 +1765,96 @@ func stringSliceEqual(a, b []string) bool {
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}