From 87677564318ec93e9ec35ec360ae6699f940f6aa Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 26 Jul 2024 00:08:18 +0100 Subject: [PATCH] Add tests --- helpers.go | 8 +- jwt.go | 50 +---- main.go | 102 +++++++-- main_test.go | 586 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 675 insertions(+), 71 deletions(-) create mode 100644 main_test.go diff --git a/helpers.go b/helpers.go index a1f72bd..24836ce 100644 --- a/helpers.go +++ b/helpers.go @@ -78,6 +78,9 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger) return } + + rw.WriteHeader(http.StatusForbidden) + rw.Write([]byte("Logged out")) } func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { @@ -97,15 +100,16 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) code := req.URL.Query().Get("code") redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath) + oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL) if err != nil { - handleError(rw, "Failed to exchange token", http.StatusInternalServerError, t.logger) + handleError(rw, "Failed to exchange token", http.StatusUnauthorized, t.logger) return false, "" } rawIDToken, ok := oauth2Token["id_token"].(string) if !ok { - handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError, t.logger) + handleError(rw, "No id_token field in oauth2 token", http.StatusUnauthorized, t.logger) return false, "" } diff --git a/jwt.go b/jwt.go index 153bb4b..de84ba1 100644 --- a/jwt.go +++ b/jwt.go @@ -132,57 +132,11 @@ func decodeSegment(seg string) (map[string]interface{}, error) { } func (t *TraefikOidc) verifyAndCacheToken(token string) error { - t.logger.Debugf("Verifying token") - if !t.limiter.Allow() { - return fmt.Errorf("rate limit exceeded") - } - - if t.tokenBlacklist.IsBlacklisted(token) { - return fmt.Errorf("token is blacklisted") - } - - if _, exists := t.tokenCache.Get(token); exists { - t.logger.Debugf("Token is valid and cached") - return nil // Token is valid and cached - } - - jwt, err := parseJWT(token) - if err != nil { - return fmt.Errorf("failed to parse JWT: %w", err) - } - - if err := t.verifyJWTSignatureAndClaims(jwt, token); err != nil { - return err - } - - expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) - t.tokenCache.Set(token, expirationTime) - - return nil + return t.tokenVerifier.VerifyToken(token) } func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error { - t.logger.Debugf("Verifying JWT signature and claims") - jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) - if err != nil { - return fmt.Errorf("failed to get JWKS: %w", err) - } - - kid, ok := jwt.Header["kid"].(string) - if !ok { - return fmt.Errorf("missing key ID in token header") - } - - publicKeyPEM, err := getPublicKeyPEM(jwks, kid) - if err != nil { - return err - } - - if err := verifySignature(token, publicKeyPEM); err != nil { - return fmt.Errorf("signature verification failed: %w", err) - } - - return jwt.Verify(t.issuerURL, t.clientID) + return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token) } func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) { diff --git a/main.go b/main.go index e8ca97b..e2c96fc 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,14 @@ import ( "golang.org/x/time/rate" ) +type TokenVerifier interface { + VerifyToken(token string) error +} + +type JWTVerifier interface { + VerifyJWTSignatureAndClaims(jwt *JWT, token string) error +} + type TraefikOidc struct { next http.Handler name string @@ -36,6 +44,8 @@ type TraefikOidc struct { httpClient *http.Client logger *Logger redirectURL string + tokenVerifier TokenVerifier + jwtVerifier JWTVerifier } type ProviderMetadata struct { @@ -45,6 +55,59 @@ type ProviderMetadata struct { JWKSURL string `json:"jwks_uri"` } +func (t *TraefikOidc) VerifyToken(token string) error { + t.logger.Debugf("Verifying token") + if !t.limiter.Allow() { + return fmt.Errorf("rate limit exceeded") + } + + if t.tokenBlacklist.IsBlacklisted(token) { + return fmt.Errorf("token is blacklisted") + } + + if _, exists := t.tokenCache.Get(token); exists { + t.logger.Debugf("Token is valid and cached") + return nil // Token is valid and cached + } + + jwt, err := parseJWT(token) + if err != nil { + return fmt.Errorf("failed to parse JWT: %w", err) + } + + if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil { + return err + } + + expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) + t.tokenCache.Set(token, expirationTime) + + return nil +} + +func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { + jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) + if err != nil { + return fmt.Errorf("failed to get JWKS: %w", err) + } + + kid, ok := jwt.Header["kid"].(string) + if !ok { + return fmt.Errorf("missing key ID in token header") + } + + publicKeyPEM, err := getPublicKeyPEM(jwks, kid) + if err != nil { + return err + } + + if err := verifySignature(token, publicKeyPEM); err != nil { + return fmt.Errorf("signature verification failed: %w", err) + } + + return jwt.Verify(t.issuerURL, t.clientID) +} + func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey)) store.Options = &sessions.Options{ @@ -82,6 +145,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h logger: NewLogger(config.LogLevel), redirectURL: "", } + + t.tokenVerifier = t + t.jwtVerifier = t t.startTokenCleanup() return t, nil } @@ -92,6 +158,9 @@ func discoverProviderMetadata(providerURL string, httpClient http.Client) (*Prov if err != nil { return nil, fmt.Errorf("failed to fetch provider metadata: %w", err) } + if resp == nil { + return nil, fmt.Errorf("received nil response from provider") + } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -150,32 +219,23 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } func (t *TraefikOidc) determineScheme(req *http.Request) string { - scheme := req.URL.Scheme - if scheme == "" { - scheme = req.Header.Get("X-Forwarded-Proto") - } - if scheme == "" { - if req.TLS != nil { - scheme = "https" - } else { - scheme = "http" - } - } if t.forceHTTPS { - scheme = "https" + return "https" } - return scheme + if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if req.TLS != nil { + return "https" + } + return "http" } func (t *TraefikOidc) determineHost(req *http.Request) string { - host := req.URL.Host - if host == "" { - host = req.Header.Get("X-Forwarded-Host") + if host := req.Header.Get("X-Forwarded-Host"); host != "" { + return host } - if host == "" { - host = req.Host - } - return host + return req.Host } func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool { @@ -232,7 +292,7 @@ func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.R } func (t *TraefikOidc) verifyToken(token string) error { - return t.verifyAndCacheToken(token) + return t.tokenVerifier.VerifyToken(token) } func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..ab06b16 --- /dev/null +++ b/main_test.go @@ -0,0 +1,586 @@ +// main_test.go + +package traefikoidc + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/sessions" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "golang.org/x/time/rate" +) + +type MockHTTPClient struct { + mock.Mock +} + +func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + +type MockSessionStore struct { + mock.Mock +} + +func (m *MockSessionStore) Get(r *http.Request, name string) (*sessions.Session, error) { + args := m.Called(r, name) + if session, ok := args.Get(0).(*sessions.Session); ok { + return session, args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockSessionStore) New(r *http.Request, name string) (*sessions.Session, error) { + args := m.Called(r, name) + return args.Get(0).(*sessions.Session), args.Error(1) +} + +func (m *MockSessionStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error { + args := m.Called(r, w, s) + return args.Error(0) +} + +type MockTokenVerifier struct { + mock.Mock +} + +func (m *MockTokenVerifier) VerifyToken(token string) error { + args := m.Called(token) + return args.Error(0) +} + +type MockJWTVerifier struct { + mock.Mock +} + +func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { + args := m.Called(jwt, token) + return args.Error(0) +} + +type TraefikOidcTestSuite struct { + suite.Suite + oidc *TraefikOidc + mockHTTPClient *MockHTTPClient + mockStore *MockSessionStore + mockTokenVerifier *MockTokenVerifier + mockJWTVerifier *MockJWTVerifier +} + +func (suite *TraefikOidcTestSuite) SetupTest() { + suite.mockHTTPClient = new(MockHTTPClient) + suite.mockStore = new(MockSessionStore) + suite.mockTokenVerifier = new(MockTokenVerifier) + suite.mockJWTVerifier = new(MockJWTVerifier) + + config := &Config{ + ProviderURL: "https://example.com", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + CallbackURL: "/callback", + LogoutURL: "/logout", + SessionEncryptionKey: "test-encryption-key", + Scopes: []string{"openid", "email", "profile"}, + } + + suite.oidc = &TraefikOidc{ + clientID: config.ClientID, + clientSecret: config.ClientSecret, + redirURLPath: config.CallbackURL, + logoutURLPath: config.LogoutURL, + store: suite.mockStore, + httpClient: &http.Client{Transport: suite.mockHTTPClient}, + jwkCache: &JWKCache{}, + tokenBlacklist: NewTokenBlacklist(), + tokenCache: NewTokenCache(), + logger: NewLogger("debug"), + limiter: rate.NewLimiter(rate.Every(time.Second), 100), + authURL: "https://example.com/auth", + tokenURL: "https://example.com/token", + jwksURL: "https://example.com/.well-known/jwks.json", + tokenVerifier: suite.mockTokenVerifier, + jwtVerifier: suite.mockJWTVerifier, + } +} + +func (suite *TraefikOidcTestSuite) TestServeHTTP_AuthenticatedUser() { + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + + claims := map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + claimsJSON, _ := json.Marshal(claims) + encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockToken := fmt.Sprintf("header.%s.signature", encodedClaims) + session.Values["id_token"] = mockToken + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil) + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + suite.oidc.next = nextHandler + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusOK, rw.Code) + suite.Equal("OK", rw.Body.String()) +} + +func (suite *TraefikOidcTestSuite) TestServeHTTP_CallbackPath() { + req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=test_state", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["csrf"] = "test_state" + session.Values["incoming_path"] = "/original_path" + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + claims := map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + "email": "test@example.com", + } + claimsJSON, _ := json.Marshal(claims) + encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockToken := fmt.Sprintf("header.%s.signature", encodedClaims) + + tokenResponse := map[string]interface{}{ + "id_token": mockToken, + } + tokenResponseJSON, _ := json.Marshal(tokenResponse) + + suite.mockHTTPClient.On("RoundTrip", mock.MatchedBy(func(req *http.Request) bool { + return strings.Contains(req.URL.String(), "token") + })).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(string(tokenResponseJSON))), + }, nil) + + suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusFound, rw.Code) + suite.Equal("/original_path", rw.Header().Get("Location")) +} + +func (suite *TraefikOidcTestSuite) TestVerifyToken() { + token := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Rfa2lkIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE1MTYyMzkxMjJ9.ZmFrZV9zaWduYXR1cmU" + + suite.mockTokenVerifier.On("VerifyToken", token).Return(nil) + + err := suite.oidc.verifyToken(token) + suite.Require().NoError(err) +} + +func (suite *TraefikOidcTestSuite) TestBuildAuthURL() { + authURL := suite.oidc.buildAuthURL("http://example.com/callback", "test_state", "test_nonce") + suite.Contains(authURL, suite.oidc.authURL) + suite.Contains(authURL, "client_id="+suite.oidc.clientID) + suite.Contains(authURL, "redirect_uri=http%3A%2F%2Fexample.com%2Fcallback") + suite.Contains(authURL, "state=test_state") + suite.Contains(authURL, "nonce=test_nonce") +} + +func (suite *TraefikOidcTestSuite) TestJWKToPEM() { + jwk := &JWK{ + N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()), + } + pem, err := jwkToPEM(jwk) + suite.Require().NoError(err) + suite.NotEmpty(pem) +} + +func (suite *TraefikOidcTestSuite) TestTokenBlacklist() { + tb := NewTokenBlacklist() + token := "test_token" + expiration := time.Now().Add(time.Hour) + + tb.Add(token, expiration) + suite.True(tb.IsBlacklisted(token)) + + tb.Cleanup() + suite.True(tb.IsBlacklisted(token)) + + tb.Add("expired_token", time.Now().Add(-time.Hour)) + tb.Cleanup() + suite.False(tb.IsBlacklisted("expired_token")) +} + +func (suite *TraefikOidcTestSuite) TestTokenCache() { + tc := NewTokenCache() + token := "test_token" + expiration := time.Now().Add(time.Hour) + + tc.Set(token, expiration) + info, exists := tc.Get(token) + suite.True(exists) + suite.Equal(token, info.Token) + suite.Equal(expiration, info.ExpiresAt) + + tc.Delete(token) + _, exists = tc.Get(token) + suite.False(exists) + + tc.Set("expired_token", time.Now().Add(-time.Hour)) + tc.Cleanup() + _, exists = tc.Get("expired_token") + suite.False(exists) +} + +func TestTraefikOidcSuite(t *testing.T) { + suite.Run(t, new(TraefikOidcTestSuite)) +} + +func (suite *TraefikOidcTestSuite) TestGenerateNonce() { + nonce, err := generateNonce() + suite.NoError(err) + suite.Len(nonce, 44) // Base64 encoded 32 bytes +} + +func (suite *TraefikOidcTestSuite) TestBuildFullURL() { + url := buildFullURL("https", "example.com", "/path") + suite.Equal("https://example.com/path", url) + + url = buildFullURL("", "example.com", "/path") + suite.Equal("http://example.com/path", url) +} + +func (suite *TraefikOidcTestSuite) TestExchangeCodeForToken() { + ctx := context.Background() + code := "test_code" + redirectURL := "http://example.com/callback" + + expectedToken := map[string]interface{}{ + "access_token": "test_access_token", + "id_token": "test_id_token", + } + tokenJSON, _ := json.Marshal(expectedToken) + + suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(tokenJSON)), + }, nil).Once() + + token, err := suite.oidc.exchangeCodeForToken(ctx, code, redirectURL) + suite.NoError(err) + suite.Equal(expectedToken, token) +} + +func (suite *TraefikOidcTestSuite) TestHandleLogout() { + req := httptest.NewRequest("GET", "http://example.com/logout", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["id_token"] = "test_token" + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.oidc.handleLogout(rw, req) + + suite.Equal(http.StatusForbidden, rw.Code) + suite.Equal("Logged out", rw.Body.String()) +} + +func (suite *TraefikOidcTestSuite) TestExtractClaims() { + tokenString := "header.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.signature" + claims, err := extractClaims(tokenString) + suite.NoError(err) + suite.Equal("1234567890", claims["sub"]) + suite.Equal("John Doe", claims["name"]) + suite.Equal(float64(1516239022), claims["iat"]) +} + +func (suite *TraefikOidcTestSuite) TestDiscoverProviderMetadata() { + providerURL := "https://example.com" + expectedMetadata := &ProviderMetadata{ + Issuer: "https://example.com", + AuthURL: "https://example.com/auth", + TokenURL: "https://example.com/token", + JWKSURL: "https://example.com/.well-known/jwks.json", + } + metadataJSON, _ := json.Marshal(expectedMetadata) + + httpClient := &http.Client{ + Transport: suite.mockHTTPClient, + } + + suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(metadataJSON)), + }, nil) + + metadata, err := discoverProviderMetadata(providerURL, *httpClient) + suite.NoError(err) + suite.Equal(expectedMetadata, metadata) +} + +func (suite *TraefikOidcTestSuite) TestDetermineScheme() { + req := httptest.NewRequest("GET", "http://example.com", nil) + scheme := suite.oidc.determineScheme(req) + suite.Equal("http", scheme) + + req.Header.Set("X-Forwarded-Proto", "https") + scheme = suite.oidc.determineScheme(req) + suite.Equal("https", scheme) + + suite.oidc.forceHTTPS = true + scheme = suite.oidc.determineScheme(req) + suite.Equal("https", scheme) +} + +func (suite *TraefikOidcTestSuite) TestDetermineHost() { + req := httptest.NewRequest("GET", "http://example.com", nil) + host := suite.oidc.determineHost(req) + suite.Equal("example.com", host) + + req.Header.Set("X-Forwarded-Host", "forwarded.example.com") + host = suite.oidc.determineHost(req) + suite.Equal("forwarded.example.com", host) +} + +func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + session.Values["id_token"] = "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature" + + suite.mockTokenVerifier.On("VerifyToken", "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature").Return(nil) + + authenticated := suite.oidc.isUserAuthenticated(session) + suite.True(authenticated) +} + +func (suite *TraefikOidcTestSuite) TestInitiateAuthentication() { + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() + session := sessions.NewSession(suite.mockStore, cookieName) + + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.oidc.initiateAuthentication(rw, req, session, "http://example.com/callback") + + suite.Equal(http.StatusFound, rw.Code) + location := rw.Header().Get("Location") + suite.Contains(location, suite.oidc.authURL) + suite.Contains(location, "redirect_uri=http%3A%2F%2Fexample.com%2Fcallback") +} + +func (suite *TraefikOidcTestSuite) TestRevokeToken() { + token := "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature" + suite.oidc.RevokeToken(token) + + _, exists := suite.oidc.tokenCache.Get(token) + suite.False(exists) + suite.True(suite.oidc.tokenBlacklist.IsBlacklisted(token)) +} + +func (suite *TraefikOidcTestSuite) TestServeHTTP_InvalidSession() { + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() + + suite.mockStore.On("Get", req, cookieName).Return((*sessions.Session)(nil), fmt.Errorf("invalid session")) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusInternalServerError, rw.Code) + suite.Contains(rw.Body.String(), "Session error") +} + +func (suite *TraefikOidcTestSuite) TestServeHTTP_ExpiredToken() { + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + session.Values["id_token"] = "expired.eyJleHAiOjF9.signature" // expired token + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusFound, rw.Code) // Should redirect to authentication +} + +func (suite *TraefikOidcTestSuite) TestHandleCallback_InvalidState() { + req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["csrf"] = "valid_state" + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusBadRequest, rw.Code) + suite.Contains(rw.Body.String(), "Invalid state parameter") +} + +func (suite *TraefikOidcTestSuite) TestHandleCallback_TokenExchangeError() { + req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=invalid_code&state=test_state", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["csrf"] = "test_state" + + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(`{"error": "invalid_grant"}`)), + }, nil) + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(http.StatusUnauthorized, rw.Code) + suite.Contains(rw.Body.String(), "Authentication failed") +} + +func (suite *TraefikOidcTestSuite) TestVerifyToken_RateLimitExceeded() { + suite.oidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 1) // Set a very low limit + + // Use up the only allowed request + suite.oidc.limiter.Allow() + + err := suite.oidc.VerifyToken("some_token") + suite.Error(err) + suite.Contains(err.Error(), "rate limit exceeded") +} + +func (suite *TraefikOidcTestSuite) TestVerifyToken_BlacklistedToken() { + token := "blacklisted_token" + suite.oidc.tokenBlacklist.Add(token, time.Now().Add(time.Hour)) + + err := suite.oidc.VerifyToken(token) + suite.Error(err) + suite.Contains(err.Error(), "token is blacklisted") +} + +func (suite *TraefikOidcTestSuite) TestExtractClaims_InvalidToken() { + invalidToken := "invalid.token.format" + claims, err := extractClaims(invalidToken) + suite.Error(err) + suite.Nil(claims) +} + +func (suite *TraefikOidcTestSuite) TestDiscoverProviderMetadata_HTTPError() { + providerURL := "https://example.com" + httpClient := &http.Client{ + Transport: suite.mockHTTPClient, + } + + suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("Internal Server Error")), + }, nil) + + metadata, err := discoverProviderMetadata(providerURL, *httpClient) + suite.Error(err) + suite.Nil(metadata) + suite.Contains(err.Error(), "failed to fetch provider metadata: status code 500") +} + +func (suite *TraefikOidcTestSuite) TestRevokeToken_InvalidToken() { + invalidToken := "invalid.token" + suite.oidc.RevokeToken(invalidToken) + + // Check that the invalid token is not added to the blacklist + suite.False(suite.oidc.tokenBlacklist.IsBlacklisted(invalidToken)) +} + +func TestTraefikOidc_ServeHTTP(t *testing.T) { + type fields struct { + next http.Handler + name string + store sessions.Store + redirURLPath string + logoutURLPath string + issuerURL string + jwkCache *JWKCache + tokenBlacklist *TokenBlacklist + jwksURL string + clientID string + clientSecret string + authURL string + tokenURL string + scopes []string + limiter *rate.Limiter + forceHTTPS bool + scheme string + tokenCache *TokenCache + httpClient *http.Client + logger *Logger + redirectURL string + tokenVerifier TokenVerifier + jwtVerifier JWTVerifier + } + type args struct { + rw http.ResponseWriter + req *http.Request + } + tests := []struct { + name string + fields fields + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &TraefikOidc{ + next: tt.fields.next, + name: tt.fields.name, + store: tt.fields.store, + redirURLPath: tt.fields.redirURLPath, + logoutURLPath: tt.fields.logoutURLPath, + issuerURL: tt.fields.issuerURL, + jwkCache: tt.fields.jwkCache, + tokenBlacklist: tt.fields.tokenBlacklist, + jwksURL: tt.fields.jwksURL, + clientID: tt.fields.clientID, + clientSecret: tt.fields.clientSecret, + authURL: tt.fields.authURL, + tokenURL: tt.fields.tokenURL, + scopes: tt.fields.scopes, + limiter: tt.fields.limiter, + forceHTTPS: tt.fields.forceHTTPS, + scheme: tt.fields.scheme, + tokenCache: tt.fields.tokenCache, + httpClient: tt.fields.httpClient, + logger: tt.fields.logger, + redirectURL: tt.fields.redirectURL, + tokenVerifier: tt.fields.tokenVerifier, + jwtVerifier: tt.fields.jwtVerifier, + } + tr.ServeHTTP(tt.args.rw, tt.args.req) + }) + } +}