Add token cache to speed up the process and reduce the number of requests to the oidc endpoint.

This commit is contained in:
2024-07-24 18:30:51 +01:00
parent 1649c72b9e
commit 6de1ccbd17
2 changed files with 66 additions and 2 deletions
+43
View File
@@ -225,3 +225,46 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
}
type TokenCache struct {
cache map[string]*TokenInfo
mutex sync.RWMutex
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: make(map[string]*TokenInfo),
}
}
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
tc.mutex.Lock()
defer tc.mutex.Unlock()
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
}
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
tc.mutex.RLock()
defer tc.mutex.RUnlock()
info, exists := tc.cache[token]
if exists && time.Now().Before(info.ExpiresAt) {
return info, true
}
return nil, false
}
func (tc *TokenCache) Cleanup() {
tc.mutex.Lock()
defer tc.mutex.Unlock()
now := time.Now()
for token, info := range tc.cache {
if now.After(info.ExpiresAt) {
delete(tc.cache, token)
}
}
}
+23 -2
View File
@@ -38,6 +38,7 @@ type TraefikOidc struct {
limiter *rate.Limiter
forceHTTPS bool
scheme string
tokenCache *TokenCache
}
type ProviderMetadata struct {
@@ -62,7 +63,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
return nil, fmt.Errorf("failed to discover provider metadata: %v", err)
}
return &TraefikOidc{
t := &TraefikOidc{
next: next,
name: name,
store: store,
@@ -78,7 +79,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
tokenURL: metadata.TokenURL,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
}, nil
tokenCache: NewTokenCache(),
}
t.startTokenCleanup()
return t, nil
}
func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) {
@@ -218,6 +223,10 @@ func (t *TraefikOidc) verifyToken(token string) error {
return errors.New("rate limit exceeded")
}
if _, exists := t.tokenCache.Get(token); exists {
return nil // Token is valid and cached
}
jwt, err := parseJWT(token)
if err != nil {
return err
@@ -272,6 +281,9 @@ func (t *TraefikOidc) verifyToken(token string) error {
return err
}
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
t.tokenCache.Set(token, expirationTime)
return nil
}
@@ -288,3 +300,12 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
// infoLogger.Printf("Built auth URL: %s", authURL)
return authURL
}
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(5 * time.Minute)
go func() {
for range ticker.C {
t.tokenCache.Cleanup()
}
}()
}