Add support for logout URL.

This commit is contained in:
2024-07-25 00:21:39 +01:00
parent 4baf3fbefd
commit 3fe92d38e0
5 changed files with 68 additions and 3 deletions
+1
View File
@@ -13,6 +13,7 @@ testData:
clientID: 1234567890.apps.googleusercontent.com
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
+27
View File
@@ -58,6 +58,27 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR
return result, nil
}
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
t.RevokeToken(idToken)
}
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
}
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
ctx := req.Context()
session, err := t.store.Get(req, cookieName)
@@ -208,6 +229,12 @@ func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
return nil, false
}
func (tc *TokenCache) Delete(token string) {
tc.mutex.Lock()
defer tc.mutex.Unlock()
delete(tc.cache, token)
}
func (tc *TokenCache) Cleanup() {
tc.mutex.Lock()
defer tc.mutex.Unlock()
+4
View File
@@ -136,6 +136,10 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error {
return fmt.Errorf("rate limit exceeded")
}
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
if _, exists := t.tokenCache.Get(token); exists {
return nil // Token is valid and cached
}
+23 -3
View File
@@ -19,6 +19,7 @@ type TraefikOidc struct {
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
jwkCache *JWKCache
tokenBlacklist *TokenBlacklist
@@ -57,13 +58,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
if err != nil {
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
}
logger := NewLogger(config.LogLevel)
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: config.LogoutURL,
issuerURL: metadata.Issuer,
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
@@ -77,9 +78,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
tokenCache: NewTokenCache(),
httpClient: &http.Client{},
logger: logger,
logger: NewLogger(config.LogLevel),
}
t.startTokenCleanup()
return t, nil
}
@@ -108,6 +108,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.scheme = t.determineScheme(req)
host := t.determineHost(req)
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
http.Error(rw, "Logged out", http.StatusOK)
return
}
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Infof("Final redirect URL: %s", redirectURL)
@@ -227,3 +233,17 @@ func (t *TraefikOidc) startTokenCleanup() {
}
}()
}
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
}
+13
View File
@@ -13,6 +13,7 @@ const (
type Config struct {
ProviderURL string `json:"providerURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes"`
@@ -55,12 +56,22 @@ func NewLogger(level string) Logger {
return &defaultLogger{level: level}
}
func (l *defaultLogger) Info(args ...interface{}) {
if l.level == "info" || l.level == "debug" {
fmt.Println(append([]interface{}{"INFO:"}, args...)...)
}
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
if l.level == "info" || l.level == "debug" {
fmt.Printf("INFO: "+format+"\n", args...)
}
}
func (l *defaultLogger) Error(args ...interface{}) {
fmt.Fprintln(os.Stderr, append([]interface{}{"ERROR:"}, args...)...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...)
}
@@ -71,7 +82,9 @@ type HTTPClient interface {
}
type Logger interface {
Info(args ...interface{})
Infof(format string, args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
}