diff --git a/helpers.go b/helpers.go index 24126bb..aabd568 100644 --- a/helpers.go +++ b/helpers.go @@ -11,6 +11,8 @@ import ( "strings" "sync" "time" + + "github.com/gorilla/sessions" ) func generateNonce() (string, error) { @@ -82,6 +84,17 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { http.Error(rw, "Logged out", http.StatusForbidden) } +func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) { + // Clear the existing session + session.Options.MaxAge = -1 + session.Values = make(map[interface{}]interface{}) + err := session.Save(req, rw) + if err != nil { + t.logger.Errorf("Failed to clear session: %v", err) + } + t.initiateAuthentication(rw, req, session, t.redirectURL) +} + func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { ctx := req.Context() session, err := t.store.Get(req, cookieName) diff --git a/main.go b/main.go index 12067db..2628efe 100644 --- a/main.go +++ b/main.go @@ -214,7 +214,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) - return // Remove the http.Error call here + return } if t.redirectURL == "" { @@ -240,13 +240,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - if t.isUserAuthenticated(session) { + authenticated, tokenExpired := t.isUserAuthenticated(session) + if authenticated { t.logger.Debugf("User is authenticated, serving content") t.next.ServeHTTP(rw, req) return } - // User is not authenticated or session has expired, start the auth process + if tokenExpired { + t.logger.Debugf("Token has expired, initiating reauthentication") + t.handleExpiredToken(rw, req, session) + return + } + + // User is not authenticated, start the auth process t.initiateAuthentication(rw, req, session, t.redirectURL) } @@ -270,35 +277,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { return req.Host } -func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool { +func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool) { authenticated, _ := session.Values["authenticated"].(bool) if authenticated { idToken, ok := session.Values["id_token"].(string) if !ok || idToken == "" { - return false + return false, false } - // Check if the token has expired claims, err := extractClaims(idToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) - return false + return false, false } exp, ok := claims["exp"].(float64) if !ok { t.logger.Errorf("Failed to get expiration time from claims") - return false + return false, false } if time.Now().Unix() > int64(exp) { t.logger.Debugf("Session has expired") - return false + return false, true // Token expired } - return t.verifyToken(idToken) == nil + return t.verifyToken(idToken) == nil, false } - return false + return false, false } func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { diff --git a/main_test.go b/main_test.go index c8175ea..6e35156 100644 --- a/main_test.go +++ b/main_test.go @@ -375,8 +375,9 @@ func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { suite.mockTokenVerifier.On("VerifyToken", "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature").Return(nil) - authenticated := suite.oidc.isUserAuthenticated(session) + authenticated, tokenExpired := suite.oidc.isUserAuthenticated(session) suite.True(authenticated) + suite.False(tokenExpired) } func (suite *TraefikOidcTestSuite) TestInitiateAuthentication() {