From 8ca669105b98848aaafcb50472ee43c328a0bc17 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 6 Nov 2024 10:28:35 +0000 Subject: [PATCH] Fix OIDC logout issue, improve test coverage, load provider once. --- helpers.go | 83 ++++++--- main.go | 99 +++++----- main_test.go | 503 +++++++++++++++++++++++++++++++++++++++++++++++++++ settings.go | 1 + 4 files changed, 616 insertions(+), 70 deletions(-) diff --git a/helpers.go b/helpers.go index 61d97d2..8460f8e 100644 --- a/helpers.go +++ b/helpers.go @@ -97,47 +97,74 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe return tokenResponse, nil } -// handleLogout handles the user logout -func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { - session, err := t.store.Get(req, cookieName) - t.logger.Debugf("Logging out user") +// handleLogout handles the logout process +func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) { + session, err := t.store.Get(r, cookieName) if err != nil { - handleError(rw, "Session error", http.StatusInternalServerError, t.logger) + handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger) return } - // Revoke tokens if available - if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" { - if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil { - t.logger.Errorf("Failed to revoke refresh token: %v", err) - } + // Get tokens from session + idToken, _ := session.Values["id_token"].(string) + refreshToken, _ := session.Values["refresh_token"].(string) + accessToken, _ := session.Values["access_token"].(string) + + // Revoke tokens if they exist + if refreshToken != "" { + t.RevokeTokenWithProvider(refreshToken, "refresh_token") t.RevokeToken(refreshToken) } - if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" { - if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil { - t.logger.Errorf("Failed to revoke access token: %v", err) - } + if accessToken != "" { + t.RevokeTokenWithProvider(accessToken, "access_token") t.RevokeToken(accessToken) } - // Remove tokens from session - delete(session.Values, "id_token") - delete(session.Values, "refresh_token") - delete(session.Values, "access_token") - delete(session.Values, "authenticated") - - // Set session options to delete the session - session.Options = defaultSessionOptions + // Clear session session.Options.MaxAge = -1 - - if err := session.Save(req, rw); err != nil { - handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger) + session.Values = make(map[interface{}]interface{}) + if err := session.Save(r, w); err != nil { + handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger) return } - // Redirect or display logout message - rw.WriteHeader(http.StatusOK) - rw.Write([]byte("Logged out successfully")) + // Determine redirect URL + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host + } + scheme := "http" + if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS { + scheme = "https" + } + baseURL := fmt.Sprintf("%s://%s/", scheme, host) + + if t.endSessionURL != "" && idToken != "" { + logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL) + if err != nil { + handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger) + return + } + http.Redirect(w, r, logoutURL, http.StatusFound) + return + } + + http.Redirect(w, r, baseURL, http.StatusFound) +} + +// BuildLogoutURL constructs the logout URL with proper encoding +func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) { + u, err := url.Parse(endSessionURL) + if err != nil { + return "", fmt.Errorf("invalid end session URL: %v", err) + } + + q := u.Query() + q.Set("id_token_hint", idToken) + q.Set("post_logout_redirect_uri", postLogoutRedirectURI) + u.RawQuery = q.Encode() + + return u.String(), nil } // handleExpiredToken handles the case when a token has expired diff --git a/main.go b/main.go index 62e5561..a581dbf 100644 --- a/main.go +++ b/main.go @@ -62,17 +62,19 @@ type TraefikOidc struct { initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) exchangeCodeForTokenFunc func(code string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - initOnce sync.Once initComplete chan struct{} + endSessionURL string + baseURL string } // ProviderMetadata holds OIDC provider metadata type ProviderMetadata struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKSURL string `json:"jwks_uri"` - RevokeURL string `json:"revocation_endpoint"` + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + RevokeURL string `json:"revocation_endpoint"` + EndSessionURL string `json:"end_session_endpoint"` } // defaultExcludedURLs are the paths that are excluded from authentication @@ -82,6 +84,14 @@ var defaultExcludedURLs = map[string]struct{}{ var newTicker = time.NewTicker +var ( + globalMetadataCache struct { + sync.Once + metadata *ProviderMetadata + err error + } +) + // VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { t.logger.Debugf("Verifying token") @@ -254,20 +264,26 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // initializeMetadata discovers and initializes the provider metadata func (t *TraefikOidc) initializeMetadata(providerURL string) { - t.initOnce.Do(func() { + globalMetadataCache.Once.Do(func() { + t.logger.Debug("Starting global provider metadata discovery") metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) - if err != nil { - t.logger.Errorf("Failed to discover provider metadata: %v", err) - } else { - t.logger.Debug("Provider metadata discovered successfully") - t.jwksURL = metadata.JWKSURL - t.authURL = metadata.AuthURL - t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer - t.revocationURL = metadata.RevokeURL - } - close(t.initComplete) + globalMetadataCache.metadata = metadata + globalMetadataCache.err = err }) + + if globalMetadataCache.err != nil { + t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err) + } else if globalMetadataCache.metadata != nil { + t.logger.Debug("Using cached provider metadata") + t.jwksURL = globalMetadataCache.metadata.JWKSURL + t.authURL = globalMetadataCache.metadata.AuthURL + t.tokenURL = globalMetadataCache.metadata.TokenURL + t.issuerURL = globalMetadataCache.metadata.Issuer + t.revocationURL = globalMetadataCache.metadata.RevokeURL + t.endSessionURL = globalMetadataCache.metadata.EndSessionURL + } + + close(t.initComplete) } // discoverProviderMetadata fetches the OIDC provider metadata @@ -620,14 +636,9 @@ func (t *TraefikOidc) RevokeToken(token string) { // Remove from cache t.tokenCache.Delete(token) - // Add to blacklist - claims, err := extractClaims(token) - if err == nil { - if exp, ok := claims["exp"].(float64); ok { - expTime := time.Unix(int64(exp), 0) - t.tokenBlacklist.Add(token, expTime) - } - } + // Add to blacklist with default expiration + expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration + t.tokenBlacklist.Add(token, expiry) } // RevokeTokenWithProvider revokes the token with the provider @@ -726,26 +737,30 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, var groups []string var roles []string - // Check for groups claim - if groupsClaim, ok := claims["groups"]; ok { - if groupsSlice, ok := groupsClaim.([]interface{}); ok { - for _, group := range groupsSlice { - if groupStr, ok := group.(string); ok { - t.logger.Debugf("Found group: %s", groupStr) - groups = append(groups, groupStr) - } + // Extract groups with type checking + if groupsClaim, exists := claims["groups"]; exists { + groupsSlice, ok := groupsClaim.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("groups claim is not an array") + } + for _, group := range groupsSlice { + if groupStr, ok := group.(string); ok { + t.logger.Debugf("Found group: %s", groupStr) + groups = append(groups, groupStr) } } } - // Check for roles claim - if rolesClaim, ok := claims["roles"]; ok { - if rolesSlice, ok := rolesClaim.([]interface{}); ok { - for _, role := range rolesSlice { - if roleStr, ok := role.(string); ok { - t.logger.Debugf("Found role: %s", roleStr) - roles = append(roles, roleStr) - } + // Extract roles with type checking + if rolesClaim, exists := claims["roles"]; exists { + rolesSlice, ok := rolesClaim.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("roles claim is not an array") + } + for _, role := range rolesSlice { + if roleStr, ok := role.(string); ok { + t.logger.Debugf("Found role: %s", roleStr) + roles = append(roles, roleStr) } } } diff --git a/main_test.go b/main_test.go index 08b3ecf..eeef592 100644 --- a/main_test.go +++ b/main_test.go @@ -820,3 +820,506 @@ func TestOIDCHandler(t *testing.T) { }) } } + +// TestHandleLogout tests the logout functionality +func TestHandleLogout(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Create mock revocation endpoint server + mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + // Verify the required parameters are present + if r.Form.Get("token") == "" { + t.Error("Missing token parameter") + } + if r.Form.Get("token_type_hint") == "" { + t.Error("Missing token_type_hint parameter") + } + w.WriteHeader(http.StatusOK) + })) + defer mockRevocationServer.Close() + + tests := []struct { + name string + setupSession func(*sessions.Session) + endSessionURL string + expectedStatus int + expectedURL string + host string + }{ + { + name: "Successful logout with end session endpoint", + setupSession: func(session *sessions.Session) { + session.Values["authenticated"] = true + session.Values["id_token"] = "test.id.token" + session.Values["refresh_token"] = "test-refresh-token" + session.Values["access_token"] = "test-access-token" + }, + endSessionURL: "https://provider/end-session", + expectedStatus: http.StatusFound, + // Fix: The entire URL should be URL-encoded + expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", + host: "test-host", + }, + { + name: "Successful logout without end session endpoint", + setupSession: func(session *sessions.Session) { + session.Values["authenticated"] = true + session.Values["id_token"] = "test.id.token" + session.Values["refresh_token"] = "test-refresh-token" + session.Values["access_token"] = "test-access-token" + }, + endSessionURL: "", + expectedStatus: http.StatusFound, + expectedURL: "http://example.com/", + host: "test-host", + }, + { + name: "Logout with empty session", + setupSession: func(session *sessions.Session) {}, + expectedStatus: http.StatusFound, + expectedURL: "http://example.com/", + host: "test-host", + }, + { + name: "Logout with invalid end session URL", + setupSession: func(session *sessions.Session) { + session.Values["authenticated"] = true + session.Values["id_token"] = "test.id.token" + }, + endSessionURL: ":\\invalid-url", + expectedStatus: http.StatusInternalServerError, + host: "test-host", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a new TraefikOidc instance for each test + tOidc := &TraefikOidc{ + store: sessions.NewCookieStore([]byte("test-secret-key")), + revocationURL: mockRevocationServer.URL, + endSessionURL: tc.endSessionURL, + scheme: "http", + logger: NewLogger("info"), + tokenBlacklist: NewTokenBlacklist(), + httpClient: &http.Client{}, + clientID: "test-client-id", + clientSecret: "test-client-secret", + tokenCache: NewTokenCache(), + forceHTTPS: false, + } + + // Create request with proper headers + req := httptest.NewRequest("GET", "/logout", nil) + req.Header.Set("Host", tc.host) + + // Create a response recorder + rr := httptest.NewRecorder() + + // Get a session + session, err := tOidc.store.Get(req, cookieName) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Setup session + tc.setupSession(session) + session.Save(req, rr) + + // Copy session cookie to request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + + // Reset response recorder + rr = httptest.NewRecorder() + + // Handle logout + tOidc.handleLogout(rr, req) + + // Check response + if rr.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) + } + + // Check redirect URL if expected + if tc.expectedURL != "" { + location := rr.Header().Get("Location") + if location != tc.expectedURL { + t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location) + } + } + + // Verify session is cleared + newSession, _ := tOidc.store.Get(req, cookieName) + if len(newSession.Values) > 0 { + t.Error("Session was not cleared") + } + if newSession.Options.MaxAge != -1 { + t.Error("Session MaxAge was not set to -1") + } + + // Check token blacklist + if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" { + if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) { + t.Error("Refresh token was not blacklisted") + } + } + if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" { + if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) { + t.Error("Access token was not blacklisted") + } + } + }) + } +} + +// TestRevokeTokenWithProvider tests the token revocation with provider +func TestRevokeTokenWithProvider(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + token string + tokenType string + statusCode int + expectError bool + }{ + { + name: "Successful token revocation", + token: "valid-token", + tokenType: "refresh_token", + statusCode: http.StatusOK, + expectError: false, + }, + { + name: "Failed token revocation", + token: "invalid-token", + tokenType: "refresh_token", + statusCode: http.StatusBadRequest, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct) + } + + // Verify form values + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + if got := r.Form.Get("token"); got != tc.token { + t.Errorf("Expected token %s, got %s", tc.token, got) + } + if got := r.Form.Get("token_type_hint"); got != tc.tokenType { + t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got) + } + if got := r.Form.Get("client_id"); got != ts.tOidc.clientID { + t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got) + } + if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret { + t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got) + } + + w.WriteHeader(tc.statusCode) + })) + defer server.Close() + + // Set revocation URL to test server + ts.tOidc.revocationURL = server.URL + + // Test token revocation + err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType) + if tc.expectError && err == nil { + t.Error("Expected error but got nil") + } + if !tc.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +// TestRevokeToken tests the token revocation functionality +func TestRevokeToken(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + token := "test.token.with.claims" + claims := map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + + // Test token revocation + t.Run("Token revocation", func(t *testing.T) { + // Create a new instance for this specific test + tOidc := &TraefikOidc{ + tokenBlacklist: NewTokenBlacklist(), + tokenCache: NewTokenCache(), + } + + // Cache the token + tOidc.tokenCache.Set(token, claims, time.Hour) + + // Revoke the token + tOidc.RevokeToken(token) + + // Verify token was removed from cache + if _, exists := tOidc.tokenCache.Get(token); exists { + t.Error("Token was not removed from cache") + } + + // Verify token was added to blacklist + if !tOidc.tokenBlacklist.IsBlacklisted(token) { + t.Error("Token was not added to blacklist") + } + }) +} + +// Add this new test function +func TestBuildLogoutURL(t *testing.T) { + tests := []struct { + name string + endSessionURL string + idToken string + postLogoutRedirect string + expectedURL string + expectError bool + }{ + { + name: "Valid URL", + endSessionURL: "https://provider/end-session", + idToken: "test.id.token", + postLogoutRedirect: "http://example.com/", + expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", + expectError: false, + }, + { + name: "Invalid URL", + endSessionURL: "://invalid-url", + idToken: "test.id.token", + postLogoutRedirect: "http://example.com/", + expectError: true, + }, + { + name: "URL with existing query parameters", + endSessionURL: "https://provider/end-session?existing=param", + idToken: "test.id.token", + postLogoutRedirect: "http://example.com/", + expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect) + + if tc.expectError { + if err == nil { + t.Error("Expected error but got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if url != tc.expectedURL { + t.Errorf("Expected URL %q, got %q", tc.expectedURL, url) + } + } + }) + } +} + +// Add this new test function +func TestHandleExpiredToken(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + setupSession func(*sessions.Session) + expectedPath string + }{ + { + name: "Basic expired token", + setupSession: func(session *sessions.Session) { + session.Values["authenticated"] = true + session.Values["id_token"] = "expired.token" + session.Values["email"] = "test@example.com" + }, + expectedPath: "/original/path", + }, + { + name: "Session with additional values", + setupSession: func(session *sessions.Session) { + session.Values["authenticated"] = true + session.Values["id_token"] = "expired.token" + session.Values["custom_value"] = "should-be-cleared" + }, + expectedPath: "/another/path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a new TraefikOidc instance for each test + tOidc := &TraefikOidc{ + store: sessions.NewCookieStore([]byte("test-secret-key")), + logger: NewLogger("info"), + redirectURL: "http://example.com/callback", + tokenVerifier: ts.tOidc.tokenVerifier, + jwtVerifier: ts.tOidc.jwtVerifier, + initComplete: make(chan struct{}), + // Add this initialization of initiateAuthenticationFunc + initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { + // Mock implementation for test + http.Redirect(rw, req, "/login", http.StatusFound) + }, + } + close(tOidc.initComplete) + + // Create request + req := httptest.NewRequest("GET", tc.expectedPath, nil) + rr := httptest.NewRecorder() + + // Get session + session, _ := tOidc.store.New(req, cookieName) + tc.setupSession(session) + + // Handle expired token + tOidc.handleExpiredToken(rr, req, session) + + // Verify session is cleaned + if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce + t.Errorf("Expected 3 session values, got %d", len(session.Values)) + } + + // Verify required values are set + if _, ok := session.Values["csrf"].(string); !ok { + t.Error("CSRF token not set") + } + if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath { + t.Errorf("Expected path %s, got %s", tc.expectedPath, path) + } + if _, ok := session.Values["nonce"].(string); !ok { + t.Error("Nonce not set") + } + + // Verify session options + if session.Options.MaxAge != defaultSessionOptions.MaxAge { + t.Error("Session MaxAge not set correctly") + } + + // Verify redirect status + if rr.Code != http.StatusFound { + t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code) + } + }) + } +} + +// Add this new test function +func TestExtractGroupsAndRoles(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + claims map[string]interface{} + expectGroups []string + expectRoles []string + expectError bool + }{ + { + name: "Valid groups and roles", + claims: map[string]interface{}{ + "groups": []interface{}{"group1", "group2"}, + "roles": []interface{}{"role1", "role2"}, + }, + expectGroups: []string{"group1", "group2"}, + expectRoles: []string{"role1", "role2"}, + expectError: false, + }, + { + name: "Empty groups and roles", + claims: map[string]interface{}{ + "groups": []interface{}{}, + "roles": []interface{}{}, + }, + expectGroups: []string{}, + expectRoles: []string{}, + expectError: false, + }, + { + name: "Invalid groups format", + claims: map[string]interface{}{ + "groups": "not-an-array", + "roles": []interface{}{"role1"}, + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create a test token with the claims + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) + if err != nil { + t.Fatalf("Failed to create test token: %v", err) + } + + groups, roles, err := ts.tOidc.extractGroupsAndRoles(token) + + if tc.expectError { + if err == nil { + t.Error("Expected error but got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Compare groups + if !stringSliceEqual(groups, tc.expectGroups) { + t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups) + } + + // Compare roles + if !stringSliceEqual(roles, tc.expectRoles) { + t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles) + } + } + }) + } +} + +// Helper function to compare string slices +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/settings.go b/settings.go index 4da83e2..053170b 100644 --- a/settings.go +++ b/settings.go @@ -30,6 +30,7 @@ type Config struct { ExcludedURLs []string `json:"excludedURLs"` AllowedUserDomains []string `json:"allowedUserDomains"` AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"` + OIDCEndSessionURL string `json:"oidcEndSessionURL"` HTTPClient *http.Client }