diff --git a/helpers.go b/helpers.go index 2a3765b..cd1e5a4 100644 --- a/helpers.go +++ b/helpers.go @@ -13,17 +13,16 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/gorilla/sessions" ) func newSessionOptions(isSecure bool) *sessions.Options { return &sessions.Options{ - HttpOnly: true, - Secure: isSecure, - SameSite: http.SameSiteLaxMode, - MaxAge: ConstSessionTimeout, - Path: "/", + HttpOnly: true, + Secure: isSecure, + SameSite: http.SameSiteLaxMode, + MaxAge: ConstSessionTimeout, + Path: "/", } } @@ -100,33 +99,21 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe } // handleExpiredToken handles the case when a token has expired -func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { +func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { // Clear the existing session - session.Options.MaxAge = -1 - for k := range session.Values { - delete(session.Values, k) - } - - // Set new values - session.Values["csrf"] = uuid.New().String() - session.Values["incoming_path"] = req.URL.Path - session.Values["nonce"], _ = generateNonce() - session.Options = newSessionOptions(t.determineScheme(req) == "https") - - // Save the session before initiating authentication - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session: %v", err) + if err := session.Clear(req, rw); err != nil { + t.logger.Errorf("Failed to clear session: %v", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } - // Initiate a new authentication flow - t.initiateAuthenticationFunc(rw, req, session, redirectURL) + // Initialize new authentication + t.defaultInitiateAuthentication(rw, req, session, redirectURL) } // handleCallback handles the callback from the OIDC provider func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { - session, err := t.store.Get(req, cookieName) + session, err := t.sessionManager.GetSession(req) if err != nil { t.logger.Errorf("Session error: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) @@ -143,26 +130,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Validate the state parameter matches the session's CSRF token + // Validate state parameter matches the session's CSRF token state := req.URL.Query().Get("state") if state == "" { t.logger.Error("No state in callback") http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) return } - csrfToken, ok := session.Values["csrf"].(string) - if !ok || csrfToken == "" { + + csrfToken := session.GetCSRF() + if csrfToken == "" { t.logger.Error("CSRF token missing in session") http.Error(rw, "CSRF token missing", http.StatusBadRequest) return } + if state != csrfToken { t.logger.Error("State parameter does not match CSRF token in session") http.Error(rw, "Invalid state parameter", http.StatusBadRequest) return } - // Proceed to exchange the code for tokens + // Exchange code for tokens code := req.URL.Query().Get("code") if code == "" { t.logger.Error("No code in callback") @@ -177,49 +166,42 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Extract id_token - idToken := tokenResponse.IDToken - if idToken == "" { - t.logger.Error("No id_token in token response") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Verify the id_token - if err := t.verifyToken(idToken); err != nil { + // Verify and process tokens + if err := t.verifyToken(tokenResponse.IDToken); err != nil { t.logger.Errorf("Failed to verify id_token: %v", err) http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } - // Extract claims from id_token - claims, err := t.extractClaimsFunc(idToken) + claims, err := t.extractClaimsFunc(tokenResponse.IDToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } - // Verify the nonce claim matches the one stored in session + // Verify nonce nonceClaim, ok := claims["nonce"].(string) if !ok || nonceClaim == "" { t.logger.Error("Nonce claim missing in id_token") http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } - sessionNonce, ok := session.Values["nonce"].(string) - if !ok || sessionNonce == "" { + + sessionNonce := session.GetNonce() + if sessionNonce == "" { t.logger.Error("Nonce not found in session") http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } + if nonceClaim != sessionNonce { t.logger.Error("Nonce claim does not match session nonce") http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } - // Get the email from claims + // Process email email, _ := claims["email"].(string) if email == "" || !t.isAllowedDomain(email) { t.logger.Errorf("Invalid or disallowed email: %s", email) @@ -227,31 +209,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Store tokens and authentication status in session - session.Values["authenticated"] = true - session.Values["email"] = email - session.Values["id_token"] = idToken - session.Values["refresh_token"] = tokenResponse.RefreshToken - session.Options = newSessionOptions(t.determineScheme(req) == "https") - - // Remove CSRF and nonce from session - delete(session.Values, "csrf") - delete(session.Values, "nonce") + // Update session with new values + session.SetAuthenticated(true) + session.SetEmail(email) + session.SetAccessToken(tokenResponse.IDToken) + session.SetRefreshToken(tokenResponse.RefreshToken) + // Save session if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save session: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) return } - t.logger.Debugf("Authentication successful. User email: %s", email) - - // Redirect to the original requested path or default to root + // Redirect to original path or root redirectPath := "/" - if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath { - t.logger.Debugf("Redirecting to incoming path from original request: %s", path) - redirectPath = path + if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { + redirectPath = incomingPath } + http.Redirect(rw, req, redirectPath, http.StatusFound) } @@ -376,21 +352,19 @@ func createStringMap(keys []string) map[string]struct{} { // handleLogout handles the logout request func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { - session, err := t.store.Get(req, cookieName) + session, err := t.sessionManager.GetSession(req) if err != nil { t.logger.Errorf("Error getting session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } - // Get the id_token before clearing the session - idToken, _ := session.Values["id_token"].(string) + // Get the access token before clearing session + accessToken := session.GetAccessToken() - // Clear and expire the session - session.Values = make(map[interface{}]interface{}) - session.Options.MaxAge = -1 - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Error saving session: %v", err) + // Clear all session data + if err := session.Clear(req, rw); err != nil { + t.logger.Errorf("Error clearing session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } @@ -401,34 +375,26 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { baseURL := fmt.Sprintf("%s://%s", scheme, host) // Determine post logout redirect URI - var postLogoutRedirectURI string - if t.postLogoutRedirectURI != "" { - // Use explicitly configured postLogoutRedirectURI - if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") { - postLogoutRedirectURI = t.postLogoutRedirectURI - } else { - postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI) - } - } else { - postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/") + postLogoutRedirectURI := t.postLogoutRedirectURI + if postLogoutRedirectURI == "" { + postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL) + } else if !strings.HasPrefix(postLogoutRedirectURI, "http") { + postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI) } - t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI) - - // If we have an end session endpoint and an ID token, use OIDC end session - if t.endSessionURL != "" && idToken != "" { - logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI) + // If we have an end session endpoint and an access token, use OIDC end session + if t.endSessionURL != "" && accessToken != "" { + logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI) if err != nil { - handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger) + t.logger.Errorf("Failed to build logout URL: %v", err) + http.Error(rw, "Logout error", http.StatusInternalServerError) return } - t.logger.Debugf("Redirecting to end session URL: %s", logoutURL) http.Redirect(rw, req, logoutURL, http.StatusFound) return } - // If no end session endpoint or no ID token, just redirect to the post logout URI - t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI) + // Otherwise, redirect to post logout URI http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound) } diff --git a/main.go b/main.go index 3d5e9a9..cbe0628 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,6 @@ import ( "time" "github.com/google/uuid" - "github.com/gorilla/sessions" "golang.org/x/time/rate" ) @@ -34,7 +33,6 @@ type JWTVerifier interface { type TraefikOidc struct { next http.Handler name string - store sessions.Store redirURLPath string logoutURLPath string issuerURL string @@ -58,13 +56,14 @@ type TraefikOidc struct { excludedURLs map[string]struct{} allowedUserDomains map[string]struct{} allowedRolesAndGroups map[string]struct{} - initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) + initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) initComplete chan struct{} endSessionURL string baseURL string postLogoutRedirectURI string + sessionManager *SessionManager } // ProviderMetadata holds OIDC provider metadata @@ -185,11 +184,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error // New creates a new instance of the OIDC middleware func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { - store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey)) - store.Options = newSessionOptions(func() bool { - return config.ForceHTTPS - }()) - // Setup HTTP client transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -221,7 +215,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t := &TraefikOidc{ next: next, name: name, - store: store, redirURLPath: config.CallbackURL, logoutURLPath: func() string { if config.LogoutURL == "" { @@ -251,9 +244,10 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h initComplete: make(chan struct{}), } + t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims t.exchangeCodeForTokenFunc = t.exchangeCodeForToken - t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { + t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { t.defaultInitiateAuthentication(rw, req, session, redirectURL) } @@ -365,53 +359,43 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError) return } - // Process the request as normal case <-req.Context().Done(): t.logger.Debug("Request cancelled") http.Error(rw, "Request cancelled", http.StatusServiceUnavailable) return } - // Check if the URL is excluded from authentication + // Check if URL is excluded if t.determineExcludedURL(req.URL.Path) { t.next.ServeHTTP(rw, req) return } - // Determine the scheme (http/https) and host - scheme := t.determineScheme(req) - host := t.determineHost(req) - redirectURL := buildFullURL(scheme, host, t.redirURLPath) - // Build the redirect URL if not already set - if redirectURL == "" { - redirectURL = buildFullURL(t.scheme, host, t.redirURLPath) - t.logger.Debugf("Redirect URL updated to: %s", redirectURL) - } - - // Get the session - session, err := t.store.Get(req, cookieName) + // Get session + session, err := t.sessionManager.GetSession(req) if err != nil { t.logger.Errorf("Error getting session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } - session.Options = newSessionOptions(scheme == "https") - t.logger.Debugf("Session contents at start: %+v", session.Values) + // Build redirect URL + scheme := t.determineScheme(req) + host := t.determineHost(req) + redirectURL := buildFullURL(scheme, host, t.redirURLPath) - // Handle logout URL + // Handle special URLs if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) return } - // Handle callback URL if req.URL.Path == t.redirURLPath { t.handleCallback(rw, req, redirectURL) return } - // Check if the user is authenticated + // Check authentication status authenticated, needsRefresh, expired := t.isUserAuthenticated(session) if expired { @@ -432,24 +416,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - // At this point, the user is authenticated - idToken, ok := session.Values["id_token"].(string) - if !ok || idToken == "" { - t.logger.Errorf("No id_token found in session") - t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return - } - - claims, err := extractClaims(idToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return - } - - email, _ := claims["email"].(string) + // Process authenticated request + email := session.GetEmail() if email == "" { - t.logger.Debugf("No email found in token claims") + t.logger.Debug("No email found in session") t.defaultInitiateAuthentication(rw, req, session, redirectURL) return } @@ -460,36 +430,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - groups, roles, err := t.extractGroupsAndRoles(idToken) - if err != nil { - t.logger.Errorf("Failed to extract groups and roles: %v", err) - } else { - // Set headers for groups and roles - if len(groups) > 0 { - req.Header.Set("X-User-Groups", strings.Join(groups, ",")) - } - if len(roles) > 0 { - req.Header.Set("X-User-Roles", strings.Join(roles, ",")) - } - } - - if len(t.allowedRolesAndGroups) > 0 { - allowed := false - for _, roleOrGroup := range append(groups, roles...) { - if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok { - allowed = true - break - } - } - if !allowed { - t.logger.Infof("User with email %s does not have any allowed roles or groups", email) - http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) - return - } - } - + // Set user information in headers req.Header.Set("X-Forwarded-User", email) + // Process the request t.next.ServeHTTP(rw, req) } @@ -528,37 +472,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { } // isUserAuthenticated checks if the user is authenticated -func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) { - authenticated, _ := session.Values["authenticated"].(bool) - t.logger.Debugf("Session authenticated value: %v", authenticated) - - if !authenticated { +func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) { + if !session.GetAuthenticated() { t.logger.Debug("User is not authenticated according to session") return false, false, false } - idToken, ok := session.Values["id_token"].(string) - if !ok || idToken == "" { - t.logger.Debug("No id_token found in session") + accessToken := session.GetAccessToken() + if accessToken == "" { + t.logger.Debug("No access token found in session") return false, false, true // Session is invalid, consider it expired } // Verify the token - if err := t.verifyToken(idToken); err != nil { + if err := t.verifyToken(accessToken); err != nil { t.logger.Errorf("Token verification failed: %v", err) return false, false, true // Token is invalid, consider it expired } - claims, err := extractClaims(idToken) + claims, err := extractClaims(accessToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) - return false, false, true // Can't read claims, consider it expired + return false, false, true } expClaim, ok := claims["exp"].(float64) if !ok { - t.logger.Errorf("Failed to get expiration time from claims") - return false, false, true // No expiration, consider it expired + t.logger.Error("Failed to get expiration time from claims") + return false, false, true } now := time.Now().Unix() @@ -566,7 +507,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool if now > expTime { t.logger.Debug("Token has expired") - return false, false, true // Token has expired + return false, false, true } gracePeriod := time.Minute * 5 @@ -575,26 +516,23 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool return true, true, false // Token will expire soon, needs refresh } - return true, false, false // Token is valid and not expiring soon + return true, false, false } // defaultInitiateAuthentication initiates the authentication process -func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { - // Generate CSRF token +func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { + // Generate CSRF token and nonce csrfToken := uuid.New().String() - session.Values["csrf"] = csrfToken - session.Values["incoming_path"] = req.URL.Path - session.Options = newSessionOptions(t.determineScheme(req) == "https") - t.logger.Debugf("Setting CSRF token: %s", csrfToken) - - // Generate nonce nonce, err := generateNonce() if err != nil { http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) return } - session.Values["nonce"] = nonce - t.logger.Debugf("Setting nonce: %s", nonce) + + // Set session values + session.SetCSRF(csrfToken) + session.SetNonce(nonce) + session.SetIncomingPath(req.URL.Path) // Save the session if err := session.Save(req, rw); err != nil { @@ -603,7 +541,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req return } - // Build the authentication URL + // Build and redirect to auth URL authURL := t.buildAuthURL(redirectURL, csrfToken, nonce) http.Redirect(rw, req, authURL, http.StatusFound) } @@ -687,10 +625,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { } // refreshToken refreshes the user's token -func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool { +func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool { t.logger.Debug("Refreshing token") - refreshToken, ok := session.Values["refresh_token"].(string) - if !ok || refreshToken == "" { + refreshToken := session.GetRefreshToken() + if refreshToken == "" { t.logger.Debug("No refresh token found in session") return false } @@ -701,16 +639,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return false } - // Verify the new id_token + // Verify the new access token if err := t.verifyToken(newToken.IDToken); err != nil { - t.logger.Errorf("Failed to verify new id_token: %v", err) + t.logger.Errorf("Failed to verify new access token: %v", err) return false } // Update session with new tokens - session.Values["id_token"] = newToken.IDToken - session.Values["refresh_token"] = newToken.RefreshToken - session.Options = newSessionOptions(t.determineScheme(req) == "https") + session.SetAccessToken(newToken.IDToken) + session.SetRefreshToken(newToken.RefreshToken) + + // Save the session if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save refreshed session: %v", err) return false diff --git a/main_test.go b/main_test.go index 7f73da4..b41eea8 100644 --- a/main_test.go +++ b/main_test.go @@ -22,13 +22,14 @@ import ( // TestSuite holds common test data and setup type TestSuite struct { - t *testing.T - rsaPrivateKey *rsa.PrivateKey - rsaPublicKey *rsa.PublicKey - ecPrivateKey *ecdsa.PrivateKey - tOidc *TraefikOidc - mockJWKCache *MockJWKCache - token string + t *testing.T + rsaPrivateKey *rsa.PrivateKey + rsaPublicKey *rsa.PublicKey + ecPrivateKey *ecdsa.PrivateKey + tOidc *TraefikOidc + mockJWKCache *MockJWKCache + token string + sessionManager *SessionManager } // Setup initializes the test suite @@ -78,6 +79,9 @@ func (ts *TestSuite) Setup() { ts.t.Fatalf("Failed to create test JWT: %v", err) } + logger := NewLogger("info") + ts.sessionManager = NewSessionManager("test-secret-key", false, logger) + // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ issuerURL: "https://test-issuer.com", @@ -89,13 +93,13 @@ func (ts *TestSuite) Setup() { limiter: rate.NewLimiter(rate.Every(time.Second), 10), tokenBlacklist: NewTokenBlacklist(), tokenCache: NewTokenCache(), - logger: NewLogger("info"), - store: sessions.NewCookieStore([]byte("test-secret-key")), + logger: logger, allowedUserDomains: map[string]struct{}{"example.com": {}}, excludedURLs: map[string]struct{}{"/favicon": {}}, httpClient: &http.Client{}, extractClaimsFunc: extractClaims, initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, } close(ts.tOidc.initComplete) ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc @@ -257,6 +261,7 @@ func TestServeHTTP(t *testing.T) { sessionValues map[interface{}]interface{} expectedStatus int expectedBody string + setupSession func(*SessionData) }{ { name: "Excluded URL", @@ -272,10 +277,10 @@ func TestServeHTTP(t *testing.T) { { name: "Authenticated request to protected URL", requestPath: "/protected", - sessionValues: map[interface{}]interface{}{ - "authenticated": true, - "email": "user@example.com", - "id_token": ts.token, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(ts.token) }, expectedStatus: http.StatusOK, expectedBody: "OK", @@ -283,52 +288,52 @@ func TestServeHTTP(t *testing.T) { { name: "Logout URL", requestPath: "/logout", - sessionValues: map[interface{}]interface{}{ - "authenticated": true, - "email": "user@example.com", - "id_token": ts.token, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(ts.token) }, expectedStatus: http.StatusOK, - expectedBody: "Logged out\n", + expectedBody: "", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Create a request req := httptest.NewRequest("GET", tc.requestPath, nil) req.Header.Set("X-Forwarded-Proto", "http") req.Header.Set("X-Forwarded-Host", "localhost") - - // Create a temporary response recorder to save the session - rrSession := httptest.NewRecorder() - - // Create a session - session, _ := ts.tOidc.store.New(req, cookieName) - if tc.sessionValues != nil { - for k, v := range tc.sessionValues { - session.Values[k] = v - } - session.Save(req, rrSession) - } - - // Copy session cookie from rrSession to request - for _, cookie := range rrSession.Result().Cookies() { - req.AddCookie(cookie) - } - - // Create a response recorder for ServeHTTP rr := httptest.NewRecorder() + // Setup session if needed + session, err := ts.tOidc.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if tc.setupSession != nil { + tc.setupSession(session) + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Copy cookies to the new request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + rr = httptest.NewRecorder() + } + // Call ServeHTTP ts.tOidc.ServeHTTP(rr, req) - // Check the response + // Check response if rr.Code != tc.expectedStatus { - t.Errorf("Test %s: expected status %d, got %d", tc.name, tc.expectedStatus, rr.Code) + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) } - if tc.expectedBody != "" && strings.TrimSpace(rr.Body.String()) != strings.TrimSpace(rr.Body.String()) { - t.Errorf("Test %s: expected body '%s', got '%s'", tc.name, tc.expectedBody, rr.Body.String()) + if tc.expectedBody != "" { + if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { + t.Errorf("Expected body %q, got %q", tc.expectedBody, body) + } } }) } @@ -459,7 +464,7 @@ func TestHandleCallback(t *testing.T) { queryParams string exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - sessionSetupFunc func(session *sessions.Session) + sessionSetupFunc func(*SessionData) expectedStatus int }{ { @@ -477,18 +482,18 @@ func TestHandleCallback(t *testing.T) { "nonce": "test-nonce", }, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusFound, }, { name: "Missing Code", queryParams: "", - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusBadRequest, }, @@ -498,9 +503,9 @@ func TestHandleCallback(t *testing.T) { exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { return nil, fmt.Errorf("exchange code error") }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, @@ -510,9 +515,9 @@ func TestHandleCallback(t *testing.T) { exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { return &TokenResponse{}, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, @@ -531,9 +536,9 @@ func TestHandleCallback(t *testing.T) { "nonce": "test-nonce", }, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusForbidden, }, @@ -552,9 +557,9 @@ func TestHandleCallback(t *testing.T) { "nonce": "test-nonce", }, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusBadRequest, }, @@ -573,9 +578,9 @@ func TestHandleCallback(t *testing.T) { "nonce": "invalid-nonce", }, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, @@ -594,9 +599,9 @@ func TestHandleCallback(t *testing.T) { // Missing nonce }, nil }, - sessionSetupFunc: func(session *sessions.Session) { - session.Values["csrf"] = "test-csrf-token" - session.Values["nonce"] = "test-nonce" + sessionSetupFunc: func(session *SessionData) { + session.SetCSRF("test-csrf-token") + session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, @@ -604,15 +609,18 @@ func TestHandleCallback(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + logger := NewLogger("info") + sessionManager := NewSessionManager("test-secret-key", false, logger) + // Create a new instance for each test to avoid state carryover tOidc := &TraefikOidc{ - store: sessions.NewCookieStore([]byte("test-secret-key")), allowedUserDomains: map[string]struct{}{"example.com": {}}, - logger: NewLogger("info"), + logger: logger, exchangeCodeForTokenFunc: tc.exchangeCodeForToken, extractClaimsFunc: tc.extractClaimsFunc, tokenVerifier: ts.tOidc.tokenVerifier, jwtVerifier: ts.tOidc.jwtVerifier, + sessionManager: sessionManager, } // Create request and response recorder @@ -620,18 +628,23 @@ func TestHandleCallback(t *testing.T) { rr := httptest.NewRecorder() // Create session - session, _ := tOidc.store.New(req, cookieName) + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } if tc.sessionSetupFunc != nil { tc.sessionSetupFunc(session) } - session.Save(req, rr) + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } - // Copy session cookie to request + // Copy cookies to the new request for _, cookie := range rr.Result().Cookies() { req.AddCookie(cookie) } - // Reset rr for the actual test + // Reset response recorder for the actual test rr = httptest.NewRecorder() // Call handleCallback @@ -849,7 +862,7 @@ func TestHandleLogout(t *testing.T) { tests := []struct { name string - setupSession func(*sessions.Session) + setupSession func(*SessionData) endSessionURL string expectedStatus int expectedURL string @@ -857,11 +870,10 @@ func TestHandleLogout(t *testing.T) { }{ { name: "Successful logout with end session endpoint", - setupSession: func(session *sessions.Session) { - session.Values["authenticated"] = true - session.Values["id_token"] = "test.id.token" - session.Values["refresh_token"] = "test-refresh-token" - session.Values["access_token"] = "test-access-token" + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetAccessToken("test.id.token") + session.SetRefreshToken("test-refresh-token") }, endSessionURL: "https://provider/end-session", expectedStatus: http.StatusFound, @@ -870,11 +882,10 @@ func TestHandleLogout(t *testing.T) { }, { name: "Successful logout without end session endpoint", - setupSession: func(session *sessions.Session) { - session.Values["authenticated"] = true - session.Values["id_token"] = "test.id.token" - session.Values["refresh_token"] = "test-refresh-token" - session.Values["access_token"] = "test-access-token" + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetAccessToken("test.id.token") + session.SetRefreshToken("test-refresh-token") }, endSessionURL: "", expectedStatus: http.StatusFound, @@ -883,16 +894,17 @@ func TestHandleLogout(t *testing.T) { }, { name: "Logout with empty session", - setupSession: func(session *sessions.Session) {}, + setupSession: func(session *SessionData) {}, expectedStatus: http.StatusFound, expectedURL: "http://example.com/", host: "test-host", }, { name: "Logout with invalid end session URL", - setupSession: func(session *sessions.Session) { - session.Values["authenticated"] = true - session.Values["id_token"] = "test.id.token" + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetAccessToken("test.id.token") + session.SetRefreshToken("test-refresh-token") }, endSessionURL: ":\\invalid-url", expectedStatus: http.StatusInternalServerError, @@ -902,19 +914,20 @@ func TestHandleLogout(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Create a new TraefikOidc instance for each test + logger := NewLogger("info") + sessionManager := NewSessionManager("test-secret-key", false, logger) tOidc := &TraefikOidc{ - store: sessions.NewCookieStore([]byte("test-secret-key")), revocationURL: mockRevocationServer.URL, endSessionURL: tc.endSessionURL, scheme: "http", - logger: NewLogger("info"), + logger: logger, tokenBlacklist: NewTokenBlacklist(), httpClient: &http.Client{}, clientID: "test-client-id", clientSecret: "test-client-secret", tokenCache: NewTokenCache(), forceHTTPS: false, + sessionManager: sessionManager, } // Create request with proper headers @@ -925,16 +938,18 @@ func TestHandleLogout(t *testing.T) { rr := httptest.NewRecorder() // Get a session - session, err := tOidc.store.Get(req, cookieName) + session, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } + if tc.setupSession != nil { + tc.setupSession(session) + } + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } - // Setup session - tc.setupSession(session) - session.Save(req, rr) - - // Copy session cookie to request + // Copy cookies to the new request for _, cookie := range rr.Result().Cookies() { req.AddCookie(cookie) } @@ -950,7 +965,6 @@ func TestHandleLogout(t *testing.T) { t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) } - // Check redirect URL if expected if tc.expectedURL != "" { location := rr.Header().Get("Location") if location != tc.expectedURL { @@ -959,23 +973,31 @@ func TestHandleLogout(t *testing.T) { } // Verify session is cleared - newSession, _ := tOidc.store.Get(req, cookieName) - if len(newSession.Values) > 0 { - t.Error("Session was not cleared") + updatedSession, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get updated session: %v", err) } - if newSession.Options.MaxAge != -1 { - t.Error("Session MaxAge was not set to -1") + + // Verify tokens are cleared + if token := updatedSession.GetAccessToken(); token != "" { + t.Error("Access token not cleared") + } + if token := updatedSession.GetRefreshToken(); token != "" { + t.Error("Refresh token not cleared") + } + if updatedSession.GetAuthenticated() { + t.Error("Session still marked as authenticated") } // Check token blacklist - if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) { - t.Error("Refresh token was not blacklisted") + if token := session.GetAccessToken(); token != "" { + if !tOidc.tokenBlacklist.IsBlacklisted(token) { + t.Error("Access token was not blacklisted") } } - if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) { - t.Error("Access token was not blacklisted") + if token := session.GetRefreshToken(); token != "" { + if !tOidc.tokenBlacklist.IsBlacklisted(token) { + t.Error("Refresh token was not blacklisted") } } }) @@ -1156,24 +1178,24 @@ func TestHandleExpiredToken(t *testing.T) { tests := []struct { name string - setupSession func(*sessions.Session) + setupSession func(*SessionData) expectedPath string }{ { name: "Basic expired token", - setupSession: func(session *sessions.Session) { - session.Values["authenticated"] = true - session.Values["id_token"] = "expired.token" - session.Values["email"] = "test@example.com" + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetAccessToken("expired.token") + session.SetEmail("test@example.com") }, expectedPath: "/original/path", }, { name: "Session with additional values", - setupSession: func(session *sessions.Session) { - session.Values["authenticated"] = true - session.Values["id_token"] = "expired.token" - session.Values["custom_value"] = "should-be-cleared" + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetAccessToken("expired.token") + session.mainSession.Values["custom_value"] = "should-be-cleared" }, expectedPath: "/another/path", }, @@ -1181,16 +1203,16 @@ func TestHandleExpiredToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Create a new TraefikOidc instance for each test + logger := NewLogger("info") + sessionManager := NewSessionManager("test-secret-key", false, logger) + tOidc := &TraefikOidc{ - store: sessions.NewCookieStore([]byte("test-secret-key")), - logger: NewLogger("info"), - tokenVerifier: ts.tOidc.tokenVerifier, - jwtVerifier: ts.tOidc.jwtVerifier, - initComplete: make(chan struct{}), - // Add this initialization of initiateAuthenticationFunc - initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { - // Mock implementation for test + sessionManager: sessionManager, + logger: logger, + tokenVerifier: ts.tOidc.tokenVerifier, + jwtVerifier: ts.tOidc.jwtVerifier, + initComplete: make(chan struct{}), + initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { http.Redirect(rw, req, "/login", http.StatusFound) }, } @@ -1201,33 +1223,40 @@ func TestHandleExpiredToken(t *testing.T) { rr := httptest.NewRecorder() // Get session - session, _ := tOidc.store.New(req, cookieName) + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Setup session data tc.setupSession(session) // Handle expired token tOidc.handleExpiredToken(rr, req, session, tc.expectedPath) - // Verify session is cleaned - if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce - t.Errorf("Expected 3 session values, got %d", len(session.Values)) + // Get the updated session to verify changes + updatedSession, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get updated session: %v", err) } - // Verify required values are set - if _, ok := session.Values["csrf"].(string); !ok { + // Verify main session values + if updatedSession.GetCSRF() == "" { t.Error("CSRF token not set") } - if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath { + if path := updatedSession.GetIncomingPath(); path != tc.expectedPath { t.Errorf("Expected path %s, got %s", tc.expectedPath, path) } - if _, ok := session.Values["nonce"].(string); !ok { + if updatedSession.GetNonce() == "" { t.Error("Nonce not set") } - defaultSessionOptions := newSessionOptions(tOidc.determineScheme(req) == "https") - - // Verify session options - if session.Options.MaxAge != defaultSessionOptions.MaxAge { - t.Error("Session MaxAge not set correctly") + // Verify tokens are cleared + if token := updatedSession.GetAccessToken(); token != "" { + t.Error("Access token not cleared") + } + if token := updatedSession.GetRefreshToken(); token != "" { + t.Error("Refresh token not cleared") } // Verify redirect status diff --git a/session.go b/session.go new file mode 100644 index 0000000..257717b --- /dev/null +++ b/session.go @@ -0,0 +1,196 @@ +package traefikoidc + +import ( + "fmt" + "net/http" + "strings" + + "github.com/gorilla/sessions" +) + +const ( + mainCookieName = "_raczylo_oidc" // Main session cookie + accessTokenCookie = "_raczylo_oidc_access" // Access token cookie + refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie +) + +// SessionManager handles multiple session cookies +type SessionManager struct { + store sessions.Store + forceHTTPS bool + logger *Logger +} + +// NewSessionManager creates a new session manager +func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { + return &SessionManager{ + store: sessions.NewCookieStore([]byte(encryptionKey)), + forceHTTPS: forceHTTPS, + logger: logger, + } +} + +// getSessionOptions returns session options based on scheme +func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { + return &sessions.Options{ + HttpOnly: true, + Secure: isSecure || sm.forceHTTPS, + SameSite: http.SameSiteLaxMode, + MaxAge: ConstSessionTimeout, + Path: "/", + } +} + +// GetSession retrieves all session data +func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { + mainSession, err := sm.store.Get(r, mainCookieName) + if err != nil { + return nil, fmt.Errorf("failed to get main session: %w", err) + } + + accessSession, err := sm.store.Get(r, accessTokenCookie) + if err != nil { + return nil, fmt.Errorf("failed to get access token session: %w", err) + } + + refreshSession, err := sm.store.Get(r, refreshTokenCookie) + if err != nil { + return nil, fmt.Errorf("failed to get refresh token session: %w", err) + } + + sessionData := &SessionData{ + manager: sm, + mainSession: mainSession, + accessSession: accessSession, + refreshSession: refreshSession, + } + + return sessionData, nil +} + +// SessionData holds all session information +type SessionData struct { + manager *SessionManager + mainSession *sessions.Session + accessSession *sessions.Session + refreshSession *sessions.Session +} + +// Save saves all session data +func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { + isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS + + // Set options for all sessions + sd.mainSession.Options = sd.manager.getSessionOptions(isSecure) + sd.accessSession.Options = sd.manager.getSessionOptions(isSecure) + sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure) + + if err := sd.mainSession.Save(r, w); err != nil { + return fmt.Errorf("failed to save main session: %w", err) + } + if err := sd.accessSession.Save(r, w); err != nil { + return fmt.Errorf("failed to save access token session: %w", err) + } + if err := sd.refreshSession.Save(r, w); err != nil { + return fmt.Errorf("failed to save refresh token session: %w", err) + } + + return nil +} + +// Clear clears all session data +func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { + // Clear and expire all sessions + sd.mainSession.Options.MaxAge = -1 + sd.accessSession.Options.MaxAge = -1 + sd.refreshSession.Options.MaxAge = -1 + + for k := range sd.mainSession.Values { + delete(sd.mainSession.Values, k) + } + for k := range sd.accessSession.Values { + delete(sd.accessSession.Values, k) + } + for k := range sd.refreshSession.Values { + delete(sd.refreshSession.Values, k) + } + + return sd.Save(r, w) +} + +// GetAuthenticated returns authentication status +func (sd *SessionData) GetAuthenticated() bool { + auth, _ := sd.mainSession.Values["authenticated"].(bool) + return auth +} + +// SetAuthenticated sets authentication status +func (sd *SessionData) SetAuthenticated(value bool) { + sd.mainSession.Values["authenticated"] = value +} + +// GetAccessToken returns the access token +func (sd *SessionData) GetAccessToken() string { + token, _ := sd.accessSession.Values["token"].(string) + return token +} + +// SetAccessToken sets the access token +func (sd *SessionData) SetAccessToken(token string) { + sd.accessSession.Values["token"] = token +} + +// GetRefreshToken returns the refresh token +func (sd *SessionData) GetRefreshToken() string { + token, _ := sd.refreshSession.Values["token"].(string) + return token +} + +// SetRefreshToken sets the refresh token +func (sd *SessionData) SetRefreshToken(token string) { + sd.refreshSession.Values["token"] = token +} + +// GetCSRF returns the CSRF token +func (sd *SessionData) GetCSRF() string { + csrf, _ := sd.mainSession.Values["csrf"].(string) + return csrf +} + +// SetCSRF sets the CSRF token +func (sd *SessionData) SetCSRF(token string) { + sd.mainSession.Values["csrf"] = token +} + +// GetNonce returns the nonce +func (sd *SessionData) GetNonce() string { + nonce, _ := sd.mainSession.Values["nonce"].(string) + return nonce +} + +// SetNonce sets the nonce +func (sd *SessionData) SetNonce(nonce string) { + sd.mainSession.Values["nonce"] = nonce +} + +// GetEmail returns the user's email +func (sd *SessionData) GetEmail() string { + email, _ := sd.mainSession.Values["email"].(string) + return email +} + +// SetEmail sets the user's email +func (sd *SessionData) SetEmail(email string) { + sd.mainSession.Values["email"] = email +} + +// GetIncomingPath returns the original incoming path +func (sd *SessionData) GetIncomingPath() string { + path, _ := sd.mainSession.Values["incoming_path"].(string) + return path +} + +// SetIncomingPath sets the original incoming path +func (sd *SessionData) SetIncomingPath(path string) { + sd.mainSession.Values["incoming_path"] = path +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..c4faee0 --- /dev/null +++ b/session_test.go @@ -0,0 +1,60 @@ +package traefikoidc + +import ( + "net/http/httptest" + "testing" +) + +func TestSessionManager(t *testing.T) { + logger := NewLogger("info") + manager := NewSessionManager("test-secret-key", false, logger) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + session, err := manager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Test setting and getting values + session.SetAuthenticated(true) + session.SetEmail("test@example.com") + session.SetAccessToken("test.access.token") + session.SetRefreshToken("test.refresh.token") + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Verify cookies are set + cookies := rr.Result().Cookies() + if len(cookies) != 3 { + t.Errorf("Expected 3 cookies, got %d", len(cookies)) + } + + // Create a new request with the cookies + newReq := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range cookies { + newReq.AddCookie(cookie) + } + + // Get the session again and verify values + newSession, err := manager.GetSession(newReq) + if err != nil { + t.Fatalf("Failed to get new session: %v", err) + } + + if !newSession.GetAuthenticated() { + t.Error("Authentication status not preserved") + } + if email := newSession.GetEmail(); email != "test@example.com" { + t.Errorf("Expected email test@example.com, got %s", email) + } + if token := newSession.GetAccessToken(); token != "test.access.token" { + t.Errorf("Expected access token test.access.token, got %s", token) + } + if token := newSession.GetRefreshToken(); token != "test.refresh.token" { + t.Errorf("Expected refresh token test.refresh.token, got %s", token) + } +}