diff --git a/google_session_test.go b/google_session_test.go index 1d5ef3b..bc442f5 100644 --- a/google_session_test.go +++ b/google_session_test.go @@ -161,9 +161,14 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) { session.GetRefreshToken()) } - // Check that the access token was updated - if session.GetAccessToken() != "new-id-token-from-google" { - t.Errorf("Access token not updated: got %s, expected 'new-id-token-from-google'", + // Check that the tokens were updated correctly + if session.GetIDToken() != "new-id-token-from-google" { + t.Errorf("ID token not updated: got %s, expected 'new-id-token-from-google'", + session.GetIDToken()) + } + + if session.GetAccessToken() != "new-access-token-from-google" { + t.Errorf("Access token not updated: got %s, expected 'new-access-token-from-google'", session.GetAccessToken()) } }) @@ -445,7 +450,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) { // Return a successful token response with a proper JWT return &TokenResponse{ IDToken: initialIDToken, - AccessToken: "google_access_token", + AccessToken: initialIDToken, // Use a valid JWT as the access token too RefreshToken: "google_refresh_token", ExpiresIn: 3600, }, nil @@ -459,8 +464,8 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) { // Return a successful refresh response with a proper JWT return &TokenResponse{ IDToken: refreshedIDToken, - AccessToken: "new_google_access_token", - RefreshToken: "", // Google doesn't always return a new refresh token + AccessToken: refreshedIDToken, // Use a valid JWT as the access token + RefreshToken: "", // Google doesn't always return a new refresh token ExpiresIn: 3600, }, nil }, diff --git a/main.go b/main.go index d0773bd..ded16fb 100644 --- a/main.go +++ b/main.go @@ -769,7 +769,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http return } - groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) + groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) // Using the actual access token if err != nil { t.logger.Errorf("Failed to extract groups and roles: %v", err) // Continue without group/role headers if extraction fails @@ -805,7 +805,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http // Set OIDC-specific headers req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) req.Header.Set("X-Auth-Request-User", email) - if idToken := session.GetAccessToken(); idToken != "" { + if idToken := session.GetIDToken(); idToken != "" { req.Header.Set("X-Auth-Request-Token", idToken) } @@ -826,8 +826,8 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http RefreshToken string Claims map[string]interface{} }{ - AccessToken: accessToken, - IdToken: accessToken, // Using access token as ID token + AccessToken: session.GetAccessToken(), + IdToken: session.GetIDToken(), RefreshToken: refreshToken, Claims: claims, } @@ -887,6 +887,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.") // Clear authentication data but preserve CSRF state if possible (though Clear might remove it) session.SetAuthenticated(false) + session.SetIDToken("") session.SetAccessToken("") session.SetRefreshToken("") session.SetEmail("") @@ -983,7 +984,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Verify tokens and claims + // Verify ID token and claims if err := t.VerifyToken(tokenResponse.IDToken); err != nil { t.logger.Errorf("Failed to verify id_token during callback: %v", err) t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError) @@ -1039,7 +1040,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } session.SetEmail(email) - session.SetAccessToken(tokenResponse.IDToken) + session.SetIDToken(tokenResponse.IDToken) + session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) // Clear CSRF, Nonce, CodeVerifier after use @@ -1569,13 +1571,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return false } - // Verify the new access token (ID token) + // Verify the new ID token if err := t.verifyToken(newToken.IDToken); err != nil { - truncatedNewToken := newToken.IDToken + truncatedToken := newToken.IDToken if len(newToken.IDToken) > 10 { - truncatedNewToken = newToken.IDToken[:10] + truncatedToken = newToken.IDToken[:10] } - t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err) + t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedToken, err) return false } @@ -1614,8 +1616,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se t.logger.Debugf("New token expires at: %v (in %v)", expiryTime, time.Until(expiryTime)) } - // Set the new access token - session.SetAccessToken(newToken.IDToken) + // Set the new tokens + session.SetIDToken(newToken.IDToken) + session.SetAccessToken(newToken.AccessToken) // Handle the refresh token if newToken.RefreshToken != "" { diff --git a/session.go b/session.go index a7781bd..de41b4d 100644 --- a/session.go +++ b/session.go @@ -757,3 +757,39 @@ func (sd *SessionData) GetIncomingPath() string { func (sd *SessionData) SetIncomingPath(path string) { sd.mainSession.Values["incoming_path"] = path } + +// GetIDToken retrieves the ID token stored in the session. +// It handles reassembling the token from multiple cookie chunks if necessary +// and decompresses it if it was stored compressed. +// +// Returns: +// - The complete, decompressed ID token string, or an empty string if not found. +func (sd *SessionData) GetIDToken() string { + token, _ := sd.mainSession.Values["id_token"].(string) + if token != "" { + compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool) + if compressed { + return decompressToken(token) + } + return token + } + return "" +} + +// SetIDToken stores the provided ID token in the session. +// +// Parameters: +// - token: The ID token string to store. +func (sd *SessionData) SetIDToken(token string) { + if token == "" { + sd.mainSession.Values["id_token"] = "" + sd.mainSession.Values["id_token_compressed"] = false + return + } + + // Compress token + compressed := compressToken(token) + + sd.mainSession.Values["id_token"] = compressed + sd.mainSession.Values["id_token_compressed"] = true +} diff --git a/templated_header_execution_test.go b/templated_header_execution_test.go index a40b522..9a0bb79 100644 --- a/templated_header_execution_test.go +++ b/templated_header_execution_test.go @@ -192,14 +192,14 @@ func TestTemplateExecutionContext(t *testing.T) { expectedValue string }{ { - name: "Access and ID token identity", + name: "Access and ID token distinction", templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}", data: templateData{ - AccessToken: "access-token", - IdToken: "access-token", // Same as AccessToken in processAuthorizedRequest + AccessToken: "access-token-value", + IdToken: "id-token-value", // Now these should be distinct values Claims: map[string]interface{}{}, }, - expectedValue: "Access: access-token ID: access-token", + expectedValue: "Access: access-token-value ID: id-token-value", }, { name: "Combining tokens and claims", diff --git a/templated_header_integration_test.go b/templated_header_integration_test.go index 67357ba..1ed1648 100644 --- a/templated_header_integration_test.go +++ b/templated_header_integration_test.go @@ -66,6 +66,28 @@ func TestTemplatedHeadersIntegration(t *testing.T) { "Authorization": "", }, }, + { + name: "ID Token Header", + headers: []TemplatedHeader{ + {Name: "X-ID-Token", Value: "{{.IdToken}}"}, + }, + expectedHeaders: map[string]string{ + // We'll update this dynamically after generating the token + "X-ID-Token": "", + }, + }, + { + name: "Both Token Types", + headers: []TemplatedHeader{ + {Name: "X-Access-Token", Value: "{{.AccessToken}}"}, + {Name: "X-ID-Token", Value: "{{.IdToken}}"}, + }, + expectedHeaders: map[string]string{ + // We'll update these dynamically after generating the tokens + "X-Access-Token": "", + "X-ID-Token": "", + }, + }, { name: "Missing Claim", headers: []TemplatedHeader{ @@ -210,9 +232,78 @@ func TestTemplatedHeadersIntegration(t *testing.T) { // Setup the session with authentication data session.SetAuthenticated(true) session.SetEmail("user@example.com") + // For most tests, set the same token for both ID and Access token for backward compatibility + session.SetIDToken(token) session.SetAccessToken(token) session.SetRefreshToken("test-refresh-token") + // For tests specifically testing token distinction, set different tokens + if tc.name == "ID Token Header" || tc.name == "Both Token Types" { + // Both tokens need to use the same key ID to avoid verification issues + idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "nonce": "test-nonce", + "jti": generateRandomString(16), + "type": "id_token", + "email": "user@example.com", // Use the standard test email directly + }) + if err != nil { + t.Fatalf("Failed to create test ID JWT: %v", err) + } + + accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "jti": generateRandomString(16), + "type": "access_token", + "scope": "openid email profile", + "email": "user@example.com", // Include email in access token too + }) + if err != nil { + t.Fatalf("Failed to create test access JWT: %v", err) + } + + // Create a proper token exchanger that won't cause nil pointer issues + tOidc.tokenExchanger = &MockTokenExchanger{ + RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: 3600, + }, nil + }, + } + + // Also add a mock token verifier to skip verification + tOidc.tokenVerifier = &MockTokenVerifier{ + VerifyFunc: func(token string) error { + return nil + }, + } + + // Set both tokens in the session + session.SetIDToken(idToken) + session.SetAccessToken(accessToken) + + // Update expectedHeaders for the token tests + if tc.name == "ID Token Header" { + tc.expectedHeaders["X-ID-Token"] = idToken + } else if tc.name == "Both Token Types" { + tc.expectedHeaders["X-ID-Token"] = idToken + tc.expectedHeaders["X-Access-Token"] = accessToken + } + } + if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } @@ -361,9 +452,23 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) { // Setup the session with authentication data session.SetAuthenticated(true) session.SetEmail("user@example.com") - session.SetAccessToken(token) + session.SetIDToken(token) // Use the new method + session.SetAccessToken(token) // Also set access token to match session.SetRefreshToken("test-refresh-token") + // Make sure these properties are set so refreshToken won't panic + tOidc.extractClaimsFunc = extractClaims // Ensure claims extraction works + tOidc.tokenExchanger = &MockTokenExchanger{ // Add a mock token exchanger + RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: token, + AccessToken: token, + RefreshToken: refreshToken, + ExpiresIn: 3600, + }, nil + }, + } + if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } diff --git a/token_handling_test.go b/token_handling_test.go new file mode 100644 index 0000000..84fffd1 --- /dev/null +++ b/token_handling_test.go @@ -0,0 +1,309 @@ +package traefikoidc + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + "text/template" + "time" + + "golang.org/x/time/rate" +) + +// TestTokenTypeDistinction tests that AccessToken and IdToken are correctly distinguished in templates +func TestTokenTypeDistinction(t *testing.T) { + // Define test data where AccessToken and IdToken are deliberately different + type templateData struct { + AccessToken string + IdToken string + RefreshToken string + Claims map[string]interface{} + } + + testData := templateData{ + AccessToken: "test-access-token-abc123", + IdToken: "test-id-token-xyz789", + RefreshToken: "test-refresh-token", + Claims: map[string]interface{}{ + "sub": "test-subject", + "email": "user@example.com", + }, + } + + // Test cases + tests := []struct { + name string + templateText string + expectedValue string + }{ + { + name: "Access Token Only", + templateText: "Bearer {{.AccessToken}}", + expectedValue: "Bearer test-access-token-abc123", + }, + { + name: "ID Token Only", + templateText: "ID: {{.IdToken}}", + expectedValue: "ID: test-id-token-xyz789", + }, + { + name: "Both Tokens", + templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}", + expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789", + }, + { + name: "Both Tokens in Authorization Format", + templateText: "Bearer {{.AccessToken}} and Bearer {{.IdToken}}", + expectedValue: "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := template.New("test").Parse(tc.templateText) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, testData) + if err != nil { + t.Fatalf("Failed to execute template: %v", err) + } + + result := buf.String() + if result != tc.expectedValue { + t.Errorf("Expected template output %q, got %q", tc.expectedValue, result) + } + }) + } +} + +// TestTokenTypeIntegration tests the integration of ID and access tokens with the middleware +func TestTokenTypeIntegration(t *testing.T) { + // Create a TestSuite to use its helper methods and fields + ts := &TestSuite{t: t} + ts.Setup() + + // Create different tokens for ID and access tokens + idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "nonce": "test-nonce", + "jti": generateRandomString(16), + "token_type": "id_token", + "email": "user@example.com", + }) + if err != nil { + t.Fatalf("Failed to create test ID JWT: %v", err) + } + + accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "jti": generateRandomString(16), + "token_type": "access_token", + "scope": "openid profile email", + "email": "user@example.com", // Add email to access token so it's available in claims + }) + if err != nil { + t.Fatalf("Failed to create test access JWT: %v", err) + } + + // Define test headers that use both token types + headers := []TemplatedHeader{ + {Name: "X-ID-Token", Value: "{{.IdToken}}"}, + {Name: "X-Access-Token", Value: "{{.AccessToken}}"}, + {Name: "Authorization", Value: "Bearer {{.AccessToken}}"}, + {Name: "X-Email-From-Claims", Value: "{{.Claims.email}}"}, + } + + // Store intercepted headers for verification + interceptedHeaders := make(map[string]string) + + // Create a test next handler that captures the headers + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture headers for verification + for _, header := range headers { + if value := r.Header.Get(header.Name); value != "" { + interceptedHeaders[header.Name] = value + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create the TraefikOidc instance + tOidc := &TraefikOidc{ + next: nextHandler, + name: "test", + redirURLPath: "/callback", + logoutURLPath: "/callback/logout", + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + tokenBlacklist: NewCache(), + tokenCache: NewTokenCache(), + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + logger: NewLogger("debug"), + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}}, + httpClient: &http.Client{}, + initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, + extractClaimsFunc: extractClaims, + headerTemplates: make(map[string]*template.Template), + } + + // Initialize and parse header templates + for _, header := range headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + t.Fatalf("Failed to parse header template for %s: %v", header.Name, err) + } + tOidc.headerTemplates[header.Name] = tmpl + } + + // Close the initComplete channel to bypass the waiting + close(tOidc.initComplete) + + // Create a test request + req := httptest.NewRequest("GET", "/protected", nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + rr := httptest.NewRecorder() + + // Create a session + session, err := tOidc.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Setup the session with authentication data + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetIDToken(idToken) // Set the ID token + session.SetAccessToken(accessToken) // Set the access token + session.SetRefreshToken("test-refresh-token") + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Add session cookies to the request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + + // Reset the response recorder for the main test + rr = httptest.NewRecorder() + + // Process the request + tOidc.ServeHTTP(rr, req) + + // Check status code + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + + // Verify headers were set correctly + expectedHeaders := map[string]string{ + "X-ID-Token": idToken, + "X-Access-Token": accessToken, + "Authorization": "Bearer " + accessToken, + "X-Email-From-Claims": "user@example.com", + } + + for name, expectedValue := range expectedHeaders { + if value, exists := interceptedHeaders[name]; !exists { + t.Errorf("Expected header %s was not set", name) + } else if value != expectedValue { + t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value) + } + } +} + +// TestSessionIDTokenAccessToken tests that the SessionData correctly stores and retrieves +// both ID tokens and access tokens separately +func TestSessionIDTokenAccessToken(t *testing.T) { + // Create a logger for the session manager + logger := NewLogger("debug") + + // Create a session manager + sessionManager, err := NewSessionManager("test-session-encryption-key-at-least-32-bytes", false, logger) + if err != nil { + t.Fatalf("Failed to create session manager: %v", err) + } + + // Create a test request + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + // Get a session + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Set test tokens + idToken := "test-id-token-123" + accessToken := "test-access-token-456" + refreshToken := "test-refresh-token-789" + + // Store tokens in session + session.SetIDToken(idToken) + session.SetAccessToken(accessToken) + session.SetRefreshToken(refreshToken) + + // Save the session + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Get cookies from response + cookies := rr.Result().Cookies() + + // Create a new request with those cookies + req2 := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range cookies { + req2.AddCookie(cookie) + } + + // Get the session again + session2, err := sessionManager.GetSession(req2) + if err != nil { + t.Fatalf("Failed to get session from request with cookies: %v", err) + } + + // Verify that the tokens were correctly stored and retrieved + retrievedIDToken := session2.GetIDToken() + retrievedAccessToken := session2.GetAccessToken() + retrievedRefreshToken := session2.GetRefreshToken() + + if retrievedIDToken != idToken { + t.Errorf("ID token mismatch: expected %q, got %q", idToken, retrievedIDToken) + } + + if retrievedAccessToken != accessToken { + t.Errorf("Access token mismatch: expected %q, got %q", accessToken, retrievedAccessToken) + } + + if retrievedRefreshToken != refreshToken { + t.Errorf("Refresh token mismatch: expected %q, got %q", refreshToken, retrievedRefreshToken) + } + + // Verify that the tokens are distinct + if retrievedIDToken == retrievedAccessToken { + t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken) + } +}