From 09daa1025cf3a1f9789ca9100d464d52d1b8370c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Thu, 6 Feb 2025 23:31:13 +0000 Subject: [PATCH] Follow multiple redirects during the OIDC flow. --- helpers.go | 18 ++++++++- main.go | 7 ++++ main_test.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 1 deletion(-) diff --git a/helpers.go b/helpers.go index 4475d3e..312289d 100644 --- a/helpers.go +++ b/helpers.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/url" "strings" "sync" @@ -68,13 +69,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken data.Set("refresh_token", codeOrToken) } + // Create a cookie jar for this request to handle redirects with cookies + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Transport: t.httpClient.Transport, + Timeout: t.httpClient.Timeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Always follow redirects for OIDC endpoints + if len(via) >= 50 { + return fmt.Errorf("stopped after 50 redirects") + } + return nil + }, + Jar: jar, + } + req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create token request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := t.httpClient.Do(req) + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to exchange tokens: %w", err) } diff --git a/main.go b/main.go index 8cac30d..8dbd3ac 100644 --- a/main.go +++ b/main.go @@ -225,6 +225,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h httpClient = &http.Client{ Timeout: time.Second * 15, // Reduced timeout Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Always follow redirects for OIDC endpoints + if len(via) >= 50 { + return fmt.Errorf("stopped after 50 redirects") + } + return nil + }, } } diff --git a/main_test.go b/main_test.go index f4147ce..f71369f 100644 --- a/main_test.go +++ b/main_test.go @@ -1648,6 +1648,111 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { // Helper function to compare string slices +// TestExchangeTokensWithRedirects tests the token exchange process with redirects +func TestExchangeTokensWithRedirects(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + setupServer func() *httptest.Server + expectError bool + errorContains string + }{ + { + name: "Successful token exchange with redirects", + setupServer: func() *httptest.Server { + redirectCount := 0 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if redirectCount < 3 { + // Set a cookie before redirecting + http.SetCookie(w, &http.Cookie{ + Name: fmt.Sprintf("redirect-cookie-%d", redirectCount), + Value: "test-value", + }) + redirectCount++ + w.Header().Set("Location", r.URL.String()) + w.WriteHeader(http.StatusFound) + return + } + + // Verify all cookies from previous redirects are present + cookies := r.Cookies() + if len(cookies) != 3 { + t.Errorf("Expected 3 cookies, got %d", len(cookies)) + } + for i := 0; i < 3; i++ { + found := false + expectedName := fmt.Sprintf("redirect-cookie-%d", i) + for _, cookie := range cookies { + if cookie.Name == expectedName { + found = true + break + } + } + if !found { + t.Errorf("Cookie %s not found", expectedName) + } + } + + // Return successful token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + IDToken: "test.id.token", + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "test-refresh-token", + }) + })) + }, + expectError: false, + }, + { + name: "Too many redirects", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", r.URL.String()) + w.WriteHeader(http.StatusFound) + })) + }, + expectError: true, + errorContains: "stopped after 50 redirects", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := tc.setupServer() + defer server.Close() + + // Configure the test instance + tOidc := ts.tOidc + tOidc.tokenURL = server.URL + + // Test token exchange + response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback") + + if tc.expectError { + if err == nil { + t.Error("Expected error but got nil") + } else if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if response == nil { + t.Error("Expected token response but got nil") + } else if response.IDToken != "test.id.token" { + t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken) + } + } + }) + } +} + func stringSliceEqual(a, b []string) bool { if len(a) != len(b) { return false