Compare commits

...

5 Commits

8 changed files with 154 additions and 39 deletions
+74
View File
@@ -0,0 +1,74 @@
package traefikoidc
import (
"testing"
"time"
)
func TestTokenBlacklist_Add(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
}
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
if blacklist.IsBlacklisted("nonExistentToken") {
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
}
}
func TestTokenBlacklist_Cleanup(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(-time.Hour) // Expired token
blacklist.Add(token, expiry)
blacklist.Cleanup()
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
}
}
func TestTokenBlacklist_Remove(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
blacklist.Remove(token)
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be removed, but it was not")
}
}
func TestTokenBlacklist_Count(t *testing.T) {
blacklist := NewTokenBlacklist()
token1 := "token1"
token2 := "token2"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token1, expiry)
blacklist.Add(token2, expiry)
if blacklist.Count() != 2 {
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
}
}
+4 -4
View File
@@ -40,7 +40,7 @@ type Cache struct {
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 1000
const DefaultMaxSize = 500
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
@@ -128,8 +128,8 @@ func (c *Cache) Cleanup() {
now := time.Now()
for key, item := range c.items {
// Only remove items that are already expired
if now.After(item.ExpiresAt) {
// Remove items that are expired or within 10% of expiration
if now.After(item.ExpiresAt) || now.Add(time.Duration(float64(item.ExpiresAt.Sub(now))*0.1)).After(item.ExpiresAt) {
c.removeItem(key)
}
}
@@ -139,7 +139,7 @@ func (c *Cache) Cleanup() {
func (c *Cache) evictOldest() {
now := time.Now()
elem := c.order.Front()
// First try to find an expired item from the front
for elem != nil {
entry := elem.Value.(lruEntry)
+13 -13
View File
@@ -9,7 +9,7 @@ import (
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
@@ -31,12 +31,12 @@ func TestTokenBlacklistSizeLimit(t *testing.T) {
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
@@ -62,14 +62,14 @@ func TestTokenBlacklistExpiredCleanup(t *testing.T) {
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
@@ -96,7 +96,7 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) {
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
@@ -105,12 +105,12 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) {
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
@@ -180,7 +180,7 @@ func TestTokenCacheMemoryUsage(t *testing.T) {
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
@@ -191,15 +191,15 @@ func TestTokenCacheMemoryUsage(t *testing.T) {
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
+12
View File
@@ -73,6 +73,7 @@ type JWKCache struct {
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
Cleanup() // Add Cleanup method to the interface
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
@@ -111,6 +112,17 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// Cleanup removes expired JWKs from the cache.
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.jwks != nil && now.After(c.expiresAt) {
c.jwks = nil
}
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
+19 -18
View File
@@ -533,20 +533,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
@@ -554,14 +554,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -697,9 +697,9 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
@@ -709,16 +709,17 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
go func() {
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Starting token cleanup cycle")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
// Removed runtime.GC() call
}
}()
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
go func() {
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Starting token cleanup cycle")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
// Removed runtime.GC() call
}
}()
}
// RevokeToken adds the token to the blacklist
+10 -4
View File
@@ -135,6 +135,12 @@ func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet
return m.JWKS, m.Err
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup implementation
m.JWKS = nil
m.Err = nil
}
// Helper function to create a JWT token
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
header := map[string]interface{}{
@@ -1776,7 +1782,7 @@ func TestBuildAuthURL(t *testing.T) {
issuerURL string
redirectURL string
state string
nonce string
nonce string
expectedPrefix string
}{
{
@@ -1785,7 +1791,7 @@ func TestBuildAuthURL(t *testing.T) {
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
@@ -1794,7 +1800,7 @@ func TestBuildAuthURL(t *testing.T) {
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
@@ -1803,7 +1809,7 @@ func TestBuildAuthURL(t *testing.T) {
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
+19
View File
@@ -19,6 +19,17 @@ func NewMetadataCache() *MetadataCache {
return &MetadataCache{}
}
// Cleanup removes expired metadata from the cache.
func (c *MetadataCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.metadata != nil && now.After(c.expiresAt) {
c.metadata = nil
}
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
@@ -48,7 +59,15 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
}
c.metadata = metadata
// Calculate expiration time based on usage patterns
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
return metadata, nil
}
+3
View File
@@ -326,6 +326,9 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
err = sd.Save(r, w)
}
// Clear transient per-request fields.
sd.request = nil
// Return session to pool.
sd.manager.sessionPool.Put(sd)