diff --git a/main.go b/main.go index ec18114..f516de1 100644 --- a/main.go +++ b/main.go @@ -400,13 +400,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - // Get session - session, err := t.sessionManager.GetSession(req) - if err != nil { - t.logger.Errorf("Error getting session: %v", err) - http.Error(rw, "Session error", http.StatusInternalServerError) - return - } +// Get session +session, err := t.sessionManager.GetSession(req) +if err != nil { + t.logger.Errorf("Error getting session: %v", err) + + // Obtain a new session and clear any residual session cookies + session, _ = t.sessionManager.GetSession(req) + session.Clear(req, rw) + + // Build redirect URL + scheme := t.determineScheme(req) + host := t.determineHost(req) + redirectURL := buildFullURL(scheme, host, t.redirURLPath) + + // Initiate authentication + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return +} // Build redirect URL scheme := t.determineScheme(req) @@ -589,7 +600,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req // Set session values session.SetCSRF(csrfToken) session.SetNonce(nonce) - session.SetIncomingPath(req.URL.Path) + session.SetIncomingPath(req.URL.RequestURI()) // Save the session if err := session.Save(req, rw); err != nil { diff --git a/main_test.go b/main_test.go index 871fb90..80e6c47 100644 --- a/main_test.go +++ b/main_test.go @@ -1646,6 +1646,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { } // Helper function to compare string slices + func stringSliceEqual(a, b []string) bool { if len(a) != len(b) { return false @@ -1657,3 +1658,30 @@ func stringSliceEqual(a, b []string) bool { } return true } + +// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path. +func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Create a request with query parameters + req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil) + rw := httptest.NewRecorder() + + // Get session + session, err := ts.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Call defaultInitiateAuthentication + redirectURL := "http://example.com/callback" + ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL) + + // Verify that the incoming path includes query parameters + incomingPath := session.GetIncomingPath() + expectedPath := "/protected/resource?param1=value1¶m2=value2" + if incomingPath != expectedPath { + t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath) + } +} diff --git a/session.go b/session.go index cc8e0ea..bb6d8a8 100644 --- a/session.go +++ b/session.go @@ -172,11 +172,11 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { // Check for absolute session timeout if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { - if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { - sessionData.Clear(r, nil) // Clear expired session - sm.sessionPool.Put(sessionData) - return nil, fmt.Errorf("session expired") - } +if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { + // Session has expired + sm.sessionPool.Put(sessionData) + return nil, fmt.Errorf("session expired") +} } sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)