From a7d42de0a485effe0c7f9bc2e464c21a68de17d4 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 24 Sep 2024 12:35:44 +0100 Subject: [PATCH] Invalidate user session with provider on logout --- helpers.go | 5 +++++ main.go | 57 ++++++++++++++++++++++++++++++++++++---------------- main_test.go | 21 +++++++++++++++++++ settings.go | 1 + 4 files changed, 67 insertions(+), 17 deletions(-) diff --git a/helpers.go b/helpers.go index 16d8def..fb7580f 100644 --- a/helpers.go +++ b/helpers.go @@ -106,6 +106,11 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { } if idToken, ok := session.Values["id_token"].(string); ok { + err := t.RevokeTokenWithProvider(idToken) + if err != nil { + handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger) + return + } t.RevokeToken(idToken) } diff --git a/main.go b/main.go index 991f3be..4f1f5b4 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net" "net/http" "net/url" @@ -32,6 +33,7 @@ type TraefikOidc struct { redirURLPath string logoutURLPath string issuerURL string + revocationURL string jwkCache *JWKCache tokenBlacklist *TokenBlacklist jwksURL string @@ -54,10 +56,11 @@ type TraefikOidc struct { } type ProviderMetadata struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKSURL string `json:"jwks_uri"` + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + RevokeURL string `json:"revocation_endpoint"` } var defaultExcludedURLs = map[string]struct{}{ @@ -176,6 +179,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return config.LogoutURL }(), issuerURL: metadata.Issuer, + revocationURL: metadata.RevokeURL, tokenBlacklist: NewTokenBlacklist(), jwkCache: &JWKCache{}, jwksURL: metadata.JWKSURL, @@ -325,7 +329,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if !t.isAllowedDomain(email) { t.logger.Infof("User with email %s is not from an allowed domain", email) - http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. Log me out", t.logoutURLPath), http.StatusForbidden) + http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) return } @@ -488,21 +492,40 @@ func (t *TraefikOidc) RevokeToken(token string) { } } -func (t *TraefikOidc) refreshSession(w http.ResponseWriter, r *http.Request) { - session, err := t.store.Get(r, cookieName) - if err != nil { - t.logger.Errorf("Error getting session: %v", err) - return +func (t *TraefikOidc) RevokeTokenWithProvider(token string) error { + t.logger.Debugf("Revoking token with provider") + + data := url.Values{ + "token": {token}, + "token_type_hint": {"access_token", "refresh_token"}, + "client_id": {t.clientID}, + "client_secret": {t.clientSecret}, } - if auth, ok := session.Values["authenticated"].(bool); ok && auth { - // Refresh the session - session.Options.MaxAge = ConstSessionTimeout - err = session.Save(r, w) - if err != nil { - t.logger.Errorf("Error saving session: %v", err) - } + // Create the request + req, err := http.NewRequestWithContext(context.Background(), "POST", t.revocationURL, strings.NewReader(data.Encode())) + if err != nil { + return fmt.Errorf("failed to create token revocation request: %w", err) } + + // Set headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Send the request + resp, err := t.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send token revocation request: %w", err) + } + defer resp.Body.Close() + + // Check the response + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("token revocation failed with status %d: %s", resp.StatusCode, string(body)) + } + + t.logger.Debugf("Token successfully revoked") + return nil } func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool { diff --git a/main_test.go b/main_test.go index 72e2c2d..1759821 100644 --- a/main_test.go +++ b/main_test.go @@ -114,6 +114,7 @@ func (suite *TraefikOidcTestSuite) SetupTest() { authURL: "https://example.com/auth", tokenURL: "https://example.com/token", jwksURL: "https://example.com/.well-known/jwks.json", + revocationURL: "https://example.com/revoke", tokenVerifier: suite.mockTokenVerifier, jwtVerifier: suite.mockJWTVerifier, } @@ -376,10 +377,20 @@ func (suite *TraefikOidcTestSuite) TestHandleLogout() { suite.mockStore.On("Get", req, cookieName).Return(session, nil) suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + // Mock the HTTP client for token revocation + suite.mockHTTPClient.On("RoundTrip", mock.MatchedBy(func(req *http.Request) bool { + return req.URL.String() == suite.oidc.revocationURL + })).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + }, nil) + suite.oidc.handleLogout(rw, req) suite.Equal(http.StatusForbidden, rw.Code) suite.Equal("Logged out\n", rw.Body.String()) + + suite.mockHTTPClient.AssertExpectations(suite.T()) } func (suite *TraefikOidcTestSuite) TestExtractClaims() { @@ -784,10 +795,20 @@ func (suite *TraefikOidcTestSuite) TestHandleLogout_CustomLogoutURL() { suite.mockStore.On("Get", req, cookieName).Return(session, nil) suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + // Mock the HTTP client for token revocation + suite.mockHTTPClient.On("RoundTrip", mock.MatchedBy(func(req *http.Request) bool { + return req.URL.String() == suite.oidc.revocationURL + })).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + }, nil) + suite.oidc.ServeHTTP(rw, req) suite.Equal(http.StatusForbidden, rw.Code) suite.Equal("Logged out\n", rw.Body.String()) + + suite.mockHTTPClient.AssertExpectations(suite.T()) } func (suite *TraefikOidcTestSuite) TestVerifyToken_RateLimitReached() { diff --git a/settings.go b/settings.go index 9052868..f477d0e 100644 --- a/settings.go +++ b/settings.go @@ -14,6 +14,7 @@ const ( type Config struct { ProviderURL string `json:"providerURL"` + RevocationURL string `json:"revocationURL"` CallbackURL string `json:"callbackURL"` LogoutURL string `json:"logoutURL"` ClientID string `json:"clientID"`