package traefikoidc

import (
	"crypto/tls"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/lukaszraczylo/traefikoidc/internal/utils"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testTLSState is a completed TLS 1.3 handshake. Tests point req.TLS at it to
// make an httptest request look as if it arrived over a TLS connection.
var testTLSState = tls.ConnectionState{
	Version:           tls.VersionTLS13,
	HandshakeComplete: true,
	ServerName:        "example.com",
}

// createMinimalMiddleware builds a TraefikOidc carrying only what the URL
// helper tests need: a no-op logger, provider endpoints, client credentials
// and scopes, with PKCE disabled and no excluded URLs.
//
// Every call returns a new instance, so a subtest may mutate the result
// without leaking state into sibling subtests.
func createMinimalMiddleware() *TraefikOidc {
	logger := newNoOpLogger()
	return &TraefikOidc{
		logger:       logger,
		issuerURL:    "https://provider.example.com",
		clientID:     "test-client",
		clientSecret: "test-secret",
		authURL:      "https://provider.example.com/authorize",
		tokenURL:     "https://provider.example.com/token",
		excludedURLs: make(map[string]struct{}),
		scopes:       []string{"openid", "profile", "email"},
		enablePKCE:   false,
	}
}

// TestDetermineScheme covers utils.DetermineScheme both with forceHTTPS off
// (scheme derived from X-Forwarded-Proto, then TLS, then plain http) and on
// (always https, regardless of headers or TLS).
func TestDetermineScheme(t *testing.T) {
	t.Run("forceHTTPS=false: backward compatibility", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = false

		t.Run("defaults to http when no headers or TLS", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "http", scheme)
		})

		t.Run("uses X-Forwarded-Proto when present", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			req.Header.Set("X-Forwarded-Proto", "https")
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme)
		})

		t.Run("X-Forwarded-Proto takes precedence over TLS", func(t *testing.T) {
			req := httptest.NewRequest("GET", "https://example.com/auth", nil)
			req.TLS = &testTLSState
			req.Header.Set("X-Forwarded-Proto", "http")
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "http", scheme)
		})

		t.Run("uses TLS when present and no X-Forwarded-Proto", func(t *testing.T) {
			req := httptest.NewRequest("GET", "https://example.com/auth", nil)
			req.TLS = &testTLSState
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme)
		})
	})

	t.Run("forceHTTPS=true: overrides all detection", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = true

		t.Run("returns https with no headers or TLS", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme, "forceHTTPS should override default http")
		})

		t.Run("returns https even with X-Forwarded-Proto: http", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			req.Header.Set("X-Forwarded-Proto", "http")
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme, "forceHTTPS should override X-Forwarded-Proto")
		})

		t.Run("returns https with X-Forwarded-Proto: https", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			req.Header.Set("X-Forwarded-Proto", "https")
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme)
		})

		t.Run("returns https with TLS connection", func(t *testing.T) {
			req := httptest.NewRequest("GET", "https://example.com/auth", nil)
			req.TLS = &testTLSState
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme)
		})

		t.Run("returns https even when all indicators suggest http", func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			req.Header.Set("X-Forwarded-Proto", "http")
			req.TLS = nil
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme, "forceHTTPS should be absolute override")
		})
	})

	t.Run("AWS ALB scenario: TLS termination at load balancer", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = true

		t.Run("simulates ALB overwriting X-Forwarded-Proto to http", func(t *testing.T) {
			// This simulates the issue from GitHub #82:
			// - Client connects via HTTPS to ALB
			// - ALB terminates TLS and forwards HTTP to Traefik
			// - Traefik overwrites X-Forwarded-Proto based on its view (HTTP)
			// - Plugin receives X-Forwarded-Proto: http (incorrect)
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			// Overwritten by Traefik.
			req.Header.Set("X-Forwarded-Proto", "http")
			// No TLS at plugin level.
			req.TLS = nil
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme, "forceHTTPS should ensure HTTPS redirect_uri despite incorrect header")
		})

		t.Run("simulates missing X-Forwarded-Proto header", func(t *testing.T) {
			// Some configurations may not set the header at all.
			req := httptest.NewRequest("GET", "http://example.com/auth", nil)
			req.TLS = nil
			scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
			assert.Equal(t, "https", scheme, "forceHTTPS should ensure HTTPS even without headers")
		})
	})
}

// TestBuildURLWithParamsErrorPaths tests error handling in buildURLWithParams.
//
// Each subtest creates its own middleware. Sharing one instance made the
// outcome of later cases depend on whether an earlier case had reset the
// issuerURL it mutated, so running a single case with -run or reordering
// the cases silently changed what was being tested.
func TestBuildURLWithParamsErrorPaths(t *testing.T) {
	t.Run("invalid issuer URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.issuerURL = "://invalid"
		params := url.Values{}
		params.Set("test", "value")
		result := middleware.buildURLWithParams("/path", params)
		assert.Empty(t, result)
	})

	t.Run("invalid relative URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("://invalid-relative", params)
		assert.Empty(t, result)
	})

	t.Run("invalid absolute URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("http://[invalid-url", params)
		assert.Empty(t, result)
	})

	t.Run("dangerous host in absolute URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("https://localhost/callback", params)
		assert.Empty(t, result)
	})

	t.Run("successful relative URL resolution", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		// Stated explicitly because the assertion below depends on this base.
		middleware.issuerURL = "https://provider.example.com"
		params := url.Values{}
		params.Set("key", "value")
		result := middleware.buildURLWithParams("/oauth/authorize", params)
		assert.NotEmpty(t, result)
		assert.Contains(t, result, "https://provider.example.com/oauth/authorize")
		assert.Contains(t, result, "key=value")
	})

	t.Run("successful absolute URL", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		params.Set("client_id", "test")
		result := middleware.buildURLWithParams("https://api.example.com/endpoint", params)
		assert.NotEmpty(t, result)
		assert.Contains(t, result, "https://api.example.com/endpoint")
		assert.Contains(t, result, "client_id=test")
	})
}

// TestValidateParsedURLCases tests URL validation edge cases.
//
// Fixture parse errors are checked with require so a bad fixture cannot feed
// a nil *url.URL into validateParsedURL. Errors are also asserted with
// require before err.Error() is read, so a regression fails the test instead
// of panicking on a nil error.
func TestValidateParsedURLCases(t *testing.T) {
	middleware := createMinimalMiddleware()

	t.Run("disallowed schemes rejected", func(t *testing.T) {
		invalidSchemes := []string{
			"ftp://example.com",
			"file:///etc/passwd",
			"javascript:alert(1)",
			"data:text/html,test",
		}
		for _, urlStr := range invalidSchemes {
			u, parseErr := url.Parse(urlStr)
			require.NoError(t, parseErr, "fixture must parse: %s", urlStr)
			err := middleware.validateParsedURL(u)
			require.Error(t, err, "should reject scheme: %s", urlStr)
			assert.Contains(t, err.Error(), "disallowed URL scheme")
		}
	})

	t.Run("http scheme allowed with warning", func(t *testing.T) {
		u, parseErr := url.Parse("http://example.com/path")
		require.NoError(t, parseErr)
		err := middleware.validateParsedURL(u)
		assert.NoError(t, err)
	})

	t.Run("missing host rejected", func(t *testing.T) {
		u := &url.URL{
			Scheme: "https",
			Host:   "",
			Path:   "/path",
		}
		err := middleware.validateParsedURL(u)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "missing host")
	})

	t.Run("path traversal rejected", func(t *testing.T) {
		u, parseErr := url.Parse("https://example.com/../../etc/passwd")
		require.NoError(t, parseErr)
		err := middleware.validateParsedURL(u)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "path traversal")
	})

	t.Run("valid URLs accepted", func(t *testing.T) {
		validURLs := []string{
			"https://example.com",
			"https://example.com/path",
			"https://sub.example.com:8080/path?query=value",
		}
		for _, urlStr := range validURLs {
			u, parseErr := url.Parse(urlStr)
			require.NoError(t, parseErr, "fixture must parse: %s", urlStr)
			err := middleware.validateParsedURL(u)
			assert.NoError(t, err, "should accept: %s", urlStr)
		}
	})
}

// TestValidateHostComprehensive checks validateHost in both directions. It
// must reject loopback, private, link-local, unspecified, multicast and
// metadata/localhost hosts, with or without a port. It must accept public
// IPs and ordinary hostnames.
func TestValidateHostComprehensive(t *testing.T) {
	middleware := createMinimalMiddleware()

	t.Run("loopback IPs rejected", func(t *testing.T) {
		loopbacks := []string{
			"127.0.0.1",
			"127.255.255.255",
			"::1",
		}
		for _, ip := range loopbacks {
			err := middleware.validateHost(ip)
			assert.Error(t, err, "should reject loopback: %s", ip)
		}
	})

	t.Run("private IPs rejected", func(t *testing.T) {
		privateIPs := []string{
			"10.0.0.1",
			"172.16.0.1",
			"192.168.1.1",
			"fd00::1",
		}
		for _, ip := range privateIPs {
			err := middleware.validateHost(ip)
			assert.Error(t, err, "should reject private IP: %s", ip)
		}
	})

	t.Run("link-local IPs rejected", func(t *testing.T) {
		linkLocal := []string{
			"169.254.1.1",
			"fe80::1",
		}
		for _, ip := range linkLocal {
			err := middleware.validateHost(ip)
			assert.Error(t, err, "should reject link-local: %s", ip)
		}
	})

	t.Run("unspecified and multicast rejected", func(t *testing.T) {
		special := []string{
			"0.0.0.0",
			"::",
			"224.0.0.1",
			"ff02::1",
		}
		for _, ip := range special {
			err := middleware.validateHost(ip)
			assert.Error(t, err, "should reject special IP: %s", ip)
		}
	})

	t.Run("dangerous hostnames rejected", func(t *testing.T) {
		dangerous := []string{
			"localhost",
			"LOCALHOST",
			"169.254.169.254",
			"metadata.google.internal",
		}
		for _, host := range dangerous {
			err := middleware.validateHost(host)
			assert.Error(t, err, "should reject: %s", host)
		}
	})

	t.Run("invalid host format rejected", func(t *testing.T) {
		invalid := []string{
			"[::1:invalid",
		}
		for _, host := range invalid {
			err := middleware.validateHost(host)
			assert.Error(t, err, "should reject invalid format: %s", host)
		}
	})

	t.Run("hosts with ports", func(t *testing.T) {
		err := middleware.validateHost("localhost:8080")
		assert.Error(t, err)

		err = middleware.validateHost("192.168.1.1:443")
		assert.Error(t, err)

		err = middleware.validateHost("example.com:443")
		assert.NoError(t, err)
	})

	t.Run("valid public IPs accepted", func(t *testing.T) {
		publicIPs := []string{
			"8.8.8.8",
			"1.1.1.1",
			"93.184.216.34",
		}
		for _, ip := range publicIPs {
			err := middleware.validateHost(ip)
			assert.NoError(t, err, "should accept public IP: %s", ip)
		}
	})

	t.Run("valid hostnames accepted", func(t *testing.T) {
		validHosts := []string{
			"example.com",
			"sub.example.com",
			"api.service.example.com:443",
		}
		for _, host := range validHosts {
			err := middleware.validateHost(host)
			assert.NoError(t, err, "should accept: %s", host)
		}
	})
}

// TestValidateURLEdgeCasesComprehensive tests the validateURL wrapper. Errors
// are asserted with require before err.Error() is inspected, so an
// unexpectedly nil error fails the test cleanly instead of panicking.
func TestValidateURLEdgeCasesComprehensive(t *testing.T) {
	middleware := createMinimalMiddleware()

	t.Run("empty URL rejected", func(t *testing.T) {
		err := middleware.validateURL("")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "empty URL")
	})

	t.Run("invalid URL format rejected", func(t *testing.T) {
		err := middleware.validateURL("ht tp://invalid url")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid URL format")
	})

	t.Run("valid URLs accepted", func(t *testing.T) {
		validURLs := []string{
			"https://example.com/path",
			"https://example.com/path?key=value",
		}
		for _, urlStr := range validURLs {
			err := middleware.validateURL(urlStr)
			assert.NoError(t, err, "should accept: %s", urlStr)
		}
	})

	t.Run("URL with dangerous host rejected", func(t *testing.T) {
		err := middleware.validateURL("https://localhost/path")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid host")
	})
}

// TestBuildAuthURLAudienceParameter checks when buildAuthURL adds an audience
// parameter. It is added for a non-empty audience that differs from the
// client ID, and omitted when the audience is empty or equals the client ID.
func TestBuildAuthURLAudienceParameter(t *testing.T) {
	t.Run("audience added when different from client_id", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.audience = "https://api.example.com"

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"",
		)

		assert.Contains(t, authURL, "audience=")
	})

	t.Run("audience not added when empty", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.audience = ""

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"",
		)

		assert.NotContains(t, authURL, "audience=")
	})

	t.Run("audience not added when equal to client_id", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.audience = middleware.clientID

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"",
		)

		assert.NotContains(t, authURL, "audience=")
	})
}

// TestBuildAuthURLPKCEParameters checks that code_challenge and
// code_challenge_method=S256 appear only when PKCE is enabled AND a
// non-empty challenge is supplied.
func TestBuildAuthURLPKCEParameters(t *testing.T) {
	t.Run("PKCE parameters added when enabled with challenge", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.enablePKCE = true

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"challenge789",
		)

		assert.Contains(t, authURL, "code_challenge=challenge789")
		assert.Contains(t, authURL, "code_challenge_method=S256")
	})

	t.Run("PKCE parameters not added when challenge empty", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.enablePKCE = true

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"", // Empty challenge
		)

		assert.NotContains(t, authURL, "code_challenge=")
	})

	t.Run("PKCE parameters not added when disabled", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.enablePKCE = false

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"challenge789",
		)

		assert.NotContains(t, authURL, "code_challenge=")
	})
}

// TestForceHTTPSIntegration tests the complete flow of building redirect URIs
// with forceHTTPS: scheme detection, host detection, buildFullURL and finally
// the redirect_uri embedded by buildAuthURL.
func TestForceHTTPSIntegration(t *testing.T) {
	t.Run("redirect_uri uses https when forceHTTPS=true", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = true

		// Simulate AWS ALB scenario: HTTP request with incorrect X-Forwarded-Proto.
		req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
		req.Header.Set("X-Forwarded-Proto", "http") // Traefik overwrote it
		req.Host = "service.example.com"
		req.TLS = nil

		// Build the full redirect URL as the middleware does.
		scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
		host := utils.DetermineHost(req)
		redirectURL := buildFullURL(scheme, host, "/oauth2/callback")

		assert.Equal(t, "https", scheme, "scheme should be https due to forceHTTPS")
		assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
			"redirect_uri should use https scheme")
	})

	t.Run("buildAuthURL contains https redirect_uri with forceHTTPS", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = true

		// Simulate building an auth URL from a request that arrived over HTTP.
		req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
		req.Header.Set("X-Forwarded-Proto", "http")
		req.Host = "service.example.com"
		req.TLS = nil

		scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
		host := utils.DetermineHost(req)
		redirectURL := buildFullURL(scheme, host, "/oauth2/callback")

		authURL := middleware.buildAuthURL(redirectURL, "state123", "nonce456", "")

		assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fservice.example.com%2Foauth2%2Fcallback",
			"auth URL should contain HTTPS redirect_uri")
		assert.NotContains(t, authURL, "redirect_uri=http%3A",
			"auth URL should not contain HTTP redirect_uri")
	})

	t.Run("without forceHTTPS respects X-Forwarded-Proto", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.forceHTTPS = false

		req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
		req.Header.Set("X-Forwarded-Proto", "https")
		req.Host = "service.example.com"

		scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
		host := utils.DetermineHost(req)
		redirectURL := buildFullURL(scheme, host, "/oauth2/callback")

		assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
			"should use https from X-Forwarded-Proto when forceHTTPS is false")
	})
}

// TestBuildAuthURLExtraAuthParams verifies that operator-configured extra
// authorization parameters are appended to the authorization URL. It also
// verifies they can never override or duplicate parameters the plugin
// itself manages.
func TestBuildAuthURLExtraAuthParams(t *testing.T) {
	t.Run("extra params are added (e.g. screen_hint=signup)", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.extraAuthParams = map[string]string{
			"screen_hint": "signup",
			"ui_locales":  "en",
		}

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"",
		)

		assert.Contains(t, authURL, "screen_hint=signup")
		assert.Contains(t, authURL, "ui_locales=en")
	})

	t.Run("nil/empty extraAuthParams is a no-op", func(t *testing.T) {
		cases := []struct {
			name   string
			params map[string]string
		}{
			{name: "nil", params: nil},
			{name: "empty", params: map[string]string{}},
		}
		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				middleware := createMinimalMiddleware()
				middleware.extraAuthParams = tc.params

				authURL := middleware.buildAuthURL(
					"https://app.com/callback",
					"state123",
					"nonce456",
					"",
				)

				assert.Contains(t, authURL, "client_id=test-client")
				assert.NotContains(t, authURL, "screen_hint")
			})
		}
	})

	t.Run("extra params CANNOT override plugin-managed params", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.extraAuthParams = map[string]string{
			"client_id":     "ATTACKER",
			"state":         "ATTACKER",
			"redirect_uri":  "https://evil.example.com",
			"response_type": "token",
		}

		authURL := middleware.buildAuthURL(
			"https://app.com/callback",
			"state123",
			"nonce456",
			"",
		)

		// Plugin-managed values must win; injected values must be absent
		// (neither replacing the plugin's value nor appended alongside it).
		assert.Contains(t, authURL, "client_id=test-client")
		assert.Contains(t, authURL, "state=state123")
		assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fapp.com%2Fcallback")
		assert.Contains(t, authURL, "response_type=code")
		assert.NotContains(t, authURL, "ATTACKER")
		assert.NotContains(t, authURL, "evil.example.com")
		assert.NotContains(t, authURL, "response_type=token")
	})
}