From 13572003a5e31b41f921ae57ba724ba105930881 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 26 Jul 2024 12:57:03 +0100 Subject: [PATCH] Even more tests, crossed 50% coverage --- helpers.go | 3 +-- main.go | 8 +++--- main_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/helpers.go b/helpers.go index 24836ce..24126bb 100644 --- a/helpers.go +++ b/helpers.go @@ -79,8 +79,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { return } - rw.WriteHeader(http.StatusForbidden) - rw.Write([]byte("Logged out")) + http.Error(rw, "Logged out", http.StatusForbidden) } func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { diff --git a/main.go b/main.go index e2c96fc..3edd08f 100644 --- a/main.go +++ b/main.go @@ -181,8 +181,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) - http.Error(rw, "Logged out", http.StatusForbidden) - return + return // Remove the http.Error call here } if t.redirectURL == "" { @@ -300,11 +299,14 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { "client_id": {t.clientID}, "response_type": {"code"}, "redirect_uri": {redirectURL}, - "scope": {strings.Join(t.scopes, " ")}, "state": {state}, "nonce": {nonce}, } + if len(t.scopes) > 0 { + params.Set("scope", strings.Join(t.scopes, " ")) + } + return fmt.Sprintf("%s?%s", t.authURL, params.Encode()) } diff --git a/main_test.go b/main_test.go index ab06b16..3a52dfb 100644 --- a/main_test.go +++ b/main_test.go @@ -28,6 +28,9 @@ type MockHTTPClient struct { func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { args := m.Called(req) + if args.Get(0) == nil { + return nil, args.Error(1) + } return args.Get(0).(*http.Response), args.Error(1) } @@ -305,7 +308,7 @@ func (suite *TraefikOidcTestSuite) TestHandleLogout() { suite.oidc.handleLogout(rw, req) suite.Equal(http.StatusForbidden, rw.Code) - suite.Equal("Logged out", rw.Body.String()) + suite.Equal("Logged out\n", rw.Body.String()) } func (suite *TraefikOidcTestSuite) TestExtractClaims() { @@ -584,3 +587,68 @@ func TestTraefikOidc_ServeHTTP(t *testing.T) { }) } } + +func (suite *TraefikOidcTestSuite) TestBuildAuthURL_CustomScopes() { + suite.oidc.scopes = []string{"openid", "email", "custom_scope"} + authURL := suite.oidc.buildAuthURL("http://example.com/callback", "test_state", "test_nonce") + suite.Contains(authURL, "scope=openid+email+custom_scope") +} + +func (suite *TraefikOidcTestSuite) TestBuildAuthURL_EmptyScopes() { + suite.oidc.scopes = []string{} + authURL := suite.oidc.buildAuthURL("http://example.com/callback", "test_state", "test_nonce") + suite.NotContains(authURL, "scope=") +} + +func (suite *TraefikOidcTestSuite) TestDetermineScheme_ForceHTTPS() { + suite.oidc.forceHTTPS = true + req := httptest.NewRequest("GET", "http://example.com", nil) + scheme := suite.oidc.determineScheme(req) + suite.Equal("https", scheme) +} + +func (suite *TraefikOidcTestSuite) TestHandleLogout_CustomLogoutURL() { + suite.oidc.logoutURLPath = "/custom-logout" + req := httptest.NewRequest("GET", "http://example.com/custom-logout", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["id_token"] = "test_token" + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusForbidden, rw.Code) + suite.Equal("Logged out\n", rw.Body.String()) +} + +func (suite *TraefikOidcTestSuite) TestVerifyToken_RateLimitReached() { + suite.oidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 1) // Set a very low limit + suite.oidc.limiter.Allow() // Use up the only allowed request + + err := suite.oidc.VerifyToken("some_token") + suite.Error(err) + suite.Contains(err.Error(), "rate limit exceeded") +} + +func (suite *TraefikOidcTestSuite) TestVerifyToken_InvalidJWTFormat() { + invalidToken := "invalid.jwt.format" + err := suite.oidc.VerifyToken(invalidToken) + suite.Error(err) + suite.Contains(err.Error(), "failed to parse JWT") +} + +func (suite *TraefikOidcTestSuite) TestDiscoverProviderMetadata_InvalidURL() { + invalidURL := "invalid-url" + httpClient := &http.Client{ + Transport: suite.mockHTTPClient, + } + + suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(nil, fmt.Errorf("invalid URL")) + + _, err := discoverProviderMetadata(invalidURL, *httpClient) + suite.Error(err) + suite.Contains(err.Error(), "failed to fetch provider metadata") +}