From 38433dfff8275164da2fe23dbfea4d315516ca13 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Mon, 16 Sep 2024 22:43:06 +0100 Subject: [PATCH] Improve handling of expired sessions --- helpers.go | 2 ++ main.go | 39 +++++++++++++++++++++++++++++---------- main_test.go | 9 ++++++++- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/helpers.go b/helpers.go index 5aa9ea0..28473a2 100644 --- a/helpers.go +++ b/helpers.go @@ -129,6 +129,8 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque if err != nil { t.logger.Errorf("Failed to clear session: %v", err) } + + // Initiate a new authentication flow t.initiateAuthentication(rw, req, session, t.redirectURL) } diff --git a/main.go b/main.go index 7c23e4f..a62b328 100644 --- a/main.go +++ b/main.go @@ -273,7 +273,26 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - authenticated, needsRefresh := t.isUserAuthenticated(session) + authenticated, needsRefresh, expired := t.isUserAuthenticated(session) + + if expired { + t.handleExpiredToken(rw, req, session) + return + } + + if !authenticated { + t.initiateAuthentication(rw, req, session, t.redirectURL) + return + } + + if needsRefresh { + refreshed := t.refreshToken(rw, req, session) + if !refreshed { + t.handleExpiredToken(rw, req, session) + return + } + } + if authenticated { if needsRefresh { // Attempt to refresh the token silently @@ -338,48 +357,48 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { return req.Host } -func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool) { +func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) { authenticated, _ := session.Values["authenticated"].(bool) if !authenticated { - return false, false + return false, false, false } idToken, ok := session.Values["id_token"].(string) if !ok || idToken == "" { - return false, false + return false, false, true // Session is invalid, consider it expired } // Verify the token if err := t.verifyToken(idToken); err != nil { t.logger.Errorf("Token verification failed: %v", err) - return false, false + return false, false, true // Token is invalid, consider it expired } claims, err := extractClaims(idToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) - return false, false + return false, false, true // Can't read claims, consider it expired } exp, ok := claims["exp"].(float64) if !ok { t.logger.Errorf("Failed to get expiration time from claims") - return false, false + return false, false, true // No expiration, consider it expired } now := time.Now().Unix() expTime := int64(exp) if now > expTime { - return false, false // Token has expired + return false, false, true // Token has expired } gracePeriod := time.Minute * 5 if time.Now().Add(gracePeriod).Unix() > expTime { - return true, true // Token will expire soon, needs refresh + return true, true, false // Token will expire soon, needs refresh } - return true, false + return true, false, false // Token is valid and not expiring soon } func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { diff --git a/main_test.go b/main_test.go index a23cb50..85ce27c 100644 --- a/main_test.go +++ b/main_test.go @@ -410,6 +410,7 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { setupSession func() *sessions.Session expectedAuth bool expectedRefresh bool + expectedExpired bool }{ { name: "Valid Token", @@ -421,6 +422,7 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { }, expectedAuth: true, expectedRefresh: false, + expectedExpired: false, }, { name: "Expired Token", @@ -432,6 +434,7 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { }, expectedAuth: false, expectedRefresh: false, + expectedExpired: true, }, { name: "Token Needs Refresh", @@ -446,6 +449,7 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { }, expectedAuth: true, expectedRefresh: true, + expectedExpired: false, }, { name: "Not Authenticated", @@ -456,6 +460,7 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { }, expectedAuth: false, expectedRefresh: false, + expectedExpired: false, }, } @@ -463,12 +468,14 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { suite.Run(tc.name, func() { session := tc.setupSession() suite.mockTokenVerifier.On("VerifyToken", mock.AnythingOfType("string")).Return(nil).Maybe() - authenticated, needsRefresh := suite.oidc.isUserAuthenticated(session) + authenticated, needsRefresh, expired := suite.oidc.isUserAuthenticated(session) suite.Equal(tc.expectedAuth, authenticated) suite.Equal(tc.expectedRefresh, needsRefresh) + suite.Equal(tc.expectedExpired, expired) }) } } + func (suite *TraefikOidcTestSuite) TestInitiateAuthentication() { req := httptest.NewRequest("GET", "http://example.com", nil) rw := httptest.NewRecorder()