diff --git a/helpers.go b/helpers.go index 07f12fd..740c154 100644 --- a/helpers.go +++ b/helpers.go @@ -225,3 +225,46 @@ func (tb *TokenBlacklist) Cleanup() { } } } + +type TokenCache struct { + cache map[string]*TokenInfo + mutex sync.RWMutex +} + +type TokenInfo struct { + Token string + ExpiresAt time.Time +} + +func NewTokenCache() *TokenCache { + return &TokenCache{ + cache: make(map[string]*TokenInfo), + } +} + +func (tc *TokenCache) Set(token string, expiresAt time.Time) { + tc.mutex.Lock() + defer tc.mutex.Unlock() + tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt} +} + +func (tc *TokenCache) Get(token string) (*TokenInfo, bool) { + tc.mutex.RLock() + defer tc.mutex.RUnlock() + info, exists := tc.cache[token] + if exists && time.Now().Before(info.ExpiresAt) { + return info, true + } + return nil, false +} + +func (tc *TokenCache) Cleanup() { + tc.mutex.Lock() + defer tc.mutex.Unlock() + now := time.Now() + for token, info := range tc.cache { + if now.After(info.ExpiresAt) { + delete(tc.cache, token) + } + } +} diff --git a/main.go b/main.go index cca6707..8327fb3 100644 --- a/main.go +++ b/main.go @@ -38,6 +38,7 @@ type TraefikOidc struct { limiter *rate.Limiter forceHTTPS bool scheme string + tokenCache *TokenCache } type ProviderMetadata struct { @@ -62,7 +63,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return nil, fmt.Errorf("failed to discover provider metadata: %v", err) } - return &TraefikOidc{ + t := &TraefikOidc{ next: next, name: name, store: store, @@ -78,7 +79,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h tokenURL: metadata.TokenURL, scopes: config.Scopes, limiter: rate.NewLimiter(rate.Every(time.Second), 100), - }, nil + tokenCache: NewTokenCache(), + } + + t.startTokenCleanup() + return t, nil } func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) { @@ -218,6 +223,10 @@ func (t *TraefikOidc) verifyToken(token string) error { return errors.New("rate limit exceeded") } + if _, exists := t.tokenCache.Get(token); exists { + return nil // Token is valid and cached + } + jwt, err := parseJWT(token) if err != nil { return err @@ -272,6 +281,9 @@ func (t *TraefikOidc) verifyToken(token string) error { return err } + expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) + t.tokenCache.Set(token, expirationTime) + return nil } @@ -288,3 +300,12 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { // infoLogger.Printf("Built auth URL: %s", authURL) return authURL } + +func (t *TraefikOidc) startTokenCleanup() { + ticker := time.NewTicker(5 * time.Minute) + go func() { + for range ticker.C { + t.tokenCache.Cleanup() + } + }() +}