diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..14038f1 --- /dev/null +++ b/cache.go @@ -0,0 +1,62 @@ +package traefikoidc + +import ( + "sync" + "time" +) + +type CacheItem struct { + Value interface{} + ExpiresAt time.Time +} + +type Cache struct { + items map[string]CacheItem + mutex sync.RWMutex +} + +func NewCache() *Cache { + return &Cache{ + items: make(map[string]CacheItem), + } +} + +func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.items[key] = CacheItem{ + Value: value, + ExpiresAt: time.Now().Add(expiration), + } +} + +func (c *Cache) Get(key string) (interface{}, bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + item, found := c.items[key] + if !found { + return nil, false + } + if time.Now().After(item.ExpiresAt) { + delete(c.items, key) + return nil, false + } + return item.Value, true +} + +func (c *Cache) Delete(key string) { + c.mutex.Lock() + defer c.mutex.Unlock() + delete(c.items, key) +} + +func (c *Cache) Cleanup() { + c.mutex.Lock() + defer c.mutex.Unlock() + now := time.Now() + for key, item := range c.items { + if now.After(item.ExpiresAt) { + delete(c.items, key) + } + } +} diff --git a/helpers.go b/helpers.go index 5cd9eb6..abb88b7 100644 --- a/helpers.go +++ b/helpers.go @@ -248,11 +248,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) { return claims, nil } -type UsedTokens struct { - tokens map[string]bool - mutex sync.RWMutex -} - type TokenBlacklist struct { blacklist map[string]time.Time mutex sync.RWMutex @@ -289,8 +284,7 @@ func (tb *TokenBlacklist) Cleanup() { } type TokenCache struct { - cache map[string]*TokenInfo - mutex sync.RWMutex + cache *Cache } type TokenInfo struct { @@ -300,41 +294,32 @@ type TokenInfo struct { func NewTokenCache() *TokenCache { return &TokenCache{ - cache: make(map[string]*TokenInfo), + cache: NewCache(), } } -func (tc *TokenCache) Set(token string, expiresAt time.Time) { - tc.mutex.Lock() - defer tc.mutex.Unlock() - tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt} +func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) { + token = "t-" + token + tc.cache.Set(token, claims, expiration) } -func (tc *TokenCache) Get(token string) (*TokenInfo, bool) { - tc.mutex.RLock() - defer tc.mutex.RUnlock() - info, exists := tc.cache[token] - if exists && time.Now().Before(info.ExpiresAt) { - return info, true +func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { + token = "t-" + token + value, found := tc.cache.Get(token) + if !found { + return nil, false } - return nil, false + claims, ok := value.(map[string]interface{}) + return claims, ok } func (tc *TokenCache) Delete(token string) { - tc.mutex.Lock() - defer tc.mutex.Unlock() - delete(tc.cache, token) + token = "t-" + token + tc.cache.Delete(token) } func (tc *TokenCache) Cleanup() { - tc.mutex.Lock() - defer tc.mutex.Unlock() - now := time.Now() - for token, info := range tc.cache { - if now.After(info.ExpiresAt) { - delete(tc.cache, token) - } - } + tc.cache.Cleanup() } func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) { diff --git a/main.go b/main.go index 1672aa1..f67c71f 100644 --- a/main.go +++ b/main.go @@ -100,7 +100,9 @@ func (t *TraefikOidc) VerifyToken(token string) error { } expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) - t.tokenCache.Set(token, expirationTime) + now := time.Now() + duration := expirationTime.Sub(now) + t.tokenCache.Set(token, jwt.Claims, duration) return nil } @@ -184,11 +186,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return dialer.DialContext(ctx, network, addr) }, ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - MaxIdleConnsPerHost: 10, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, } var httpClient *http.Client diff --git a/main_test.go b/main_test.go index fa1bcdc..7923bac 100644 --- a/main_test.go +++ b/main_test.go @@ -220,7 +220,9 @@ func TestVerifyToken(t *testing.T) { } if tc.cacheToken { - ts.tOidc.tokenCache.Set(tc.token, time.Now().Add(1*time.Hour)) + ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{ + "empty": "claim", + }, 60) } err := ts.tOidc.VerifyToken(tc.token)