From 3fe92d38e04c4b3c71f847cea9cbf9a09dd626da Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Thu, 25 Jul 2024 00:21:39 +0100 Subject: [PATCH] Add support for logout URL. --- README.md | 1 + helpers.go | 27 +++++++++++++++++++++++++++ jwt.go | 4 ++++ main.go | 26 +++++++++++++++++++++++--- settings.go | 13 +++++++++++++ 5 files changed, 68 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2a83c30..c70d095 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ testData: clientID: 1234567890.apps.googleusercontent.com clientSecret: secret callbackURL: /oauth2/callback + logoutURL: /oauth2/logout scopes: - openid - email diff --git a/helpers.go b/helpers.go index eadc23f..9b5f3d6 100644 --- a/helpers.go +++ b/helpers.go @@ -58,6 +58,27 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR return result, nil } +func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { + session, err := t.store.Get(req, cookieName) + if err != nil { + handleError(rw, "Session error", http.StatusInternalServerError, t.logger) + return + } + + if idToken, ok := session.Values["id_token"].(string); ok { + t.RevokeToken(idToken) + } + + // Clear the session + session.Options.MaxAge = -1 + session.Values = make(map[interface{}]interface{}) + err = session.Save(req, rw) + if err != nil { + handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger) + return + } +} + func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { ctx := req.Context() session, err := t.store.Get(req, cookieName) @@ -208,6 +229,12 @@ func (tc *TokenCache) Get(token string) (*TokenInfo, bool) { return nil, false } +func (tc *TokenCache) Delete(token string) { + tc.mutex.Lock() + defer tc.mutex.Unlock() + delete(tc.cache, token) +} + func (tc *TokenCache) Cleanup() { tc.mutex.Lock() defer tc.mutex.Unlock() diff --git a/jwt.go b/jwt.go index 9279140..2e9a23a 100644 --- a/jwt.go +++ b/jwt.go @@ -136,6 +136,10 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error { return fmt.Errorf("rate limit exceeded") } + if t.tokenBlacklist.IsBlacklisted(token) { + return fmt.Errorf("token is blacklisted") + } + if _, exists := t.tokenCache.Get(token); exists { return nil // Token is valid and cached } diff --git a/main.go b/main.go index e8a684a..f84e452 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ type TraefikOidc struct { name string store sessions.Store redirURLPath string + logoutURLPath string issuerURL string jwkCache *JWKCache tokenBlacklist *TokenBlacklist @@ -57,13 +58,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h if err != nil { return nil, fmt.Errorf("failed to discover provider metadata: %w", err) } - logger := NewLogger(config.LogLevel) t := &TraefikOidc{ next: next, name: name, store: store, redirURLPath: config.CallbackURL, + logoutURLPath: config.LogoutURL, issuerURL: metadata.Issuer, tokenBlacklist: NewTokenBlacklist(), jwkCache: &JWKCache{}, @@ -77,9 +78,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h limiter: rate.NewLimiter(rate.Every(time.Second), 100), tokenCache: NewTokenCache(), httpClient: &http.Client{}, - logger: logger, + logger: NewLogger(config.LogLevel), } - t.startTokenCleanup() return t, nil } @@ -108,6 +108,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.scheme = t.determineScheme(req) host := t.determineHost(req) + if req.URL.Path == t.logoutURLPath { + t.handleLogout(rw, req) + http.Error(rw, "Logged out", http.StatusOK) + return + } + redirectURL := buildFullURL(t.scheme, host, t.redirURLPath) t.logger.Infof("Final redirect URL: %s", redirectURL) @@ -227,3 +233,17 @@ func (t *TraefikOidc) startTokenCleanup() { } }() } + +func (t *TraefikOidc) RevokeToken(token string) { + // Remove from cache + t.tokenCache.Delete(token) + + // Add to blacklist + claims, err := extractClaims(token) + if err == nil { + if exp, ok := claims["exp"].(float64); ok { + expTime := time.Unix(int64(exp), 0) + t.tokenBlacklist.Add(token, expTime) + } + } +} diff --git a/settings.go b/settings.go index c8db43b..8572450 100644 --- a/settings.go +++ b/settings.go @@ -13,6 +13,7 @@ const ( type Config struct { ProviderURL string `json:"providerURL"` CallbackURL string `json:"callbackURL"` + LogoutURL string `json:"logoutURL"` ClientID string `json:"clientID"` ClientSecret string `json:"clientSecret"` Scopes []string `json:"scopes"` @@ -55,12 +56,22 @@ func NewLogger(level string) Logger { return &defaultLogger{level: level} } +func (l *defaultLogger) Info(args ...interface{}) { + if l.level == "info" || l.level == "debug" { + fmt.Println(append([]interface{}{"INFO:"}, args...)...) + } +} + func (l *defaultLogger) Infof(format string, args ...interface{}) { if l.level == "info" || l.level == "debug" { fmt.Printf("INFO: "+format+"\n", args...) } } +func (l *defaultLogger) Error(args ...interface{}) { + fmt.Fprintln(os.Stderr, append([]interface{}{"ERROR:"}, args...)...) +} + func (l *defaultLogger) Errorf(format string, args ...interface{}) { fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...) } @@ -71,7 +82,9 @@ type HTTPClient interface { } type Logger interface { + Info(args ...interface{}) Infof(format string, args ...interface{}) + Error(args ...interface{}) Errorf(format string, args ...interface{}) }