Files
traefikoidc/url_helpers_ultra_test.go
T
paiking1 cf6ed1da55 feat: add extraAuthParams (extra authorization request parameters) (#139)
Adds optional extraAuthParams map[string]string config.

Extra params are appended to the authorization request but can never
override plugin-managed params (client_id, state, nonce, etc.).
2026-05-27 21:41:09 +01:00

608 lines
19 KiB
Go

package traefikoidc
import (
"crypto/tls"
"net/http/httptest"
"net/url"
"testing"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTLSState describes a completed TLS 1.3 handshake for "example.com".
// Tests point req.TLS at it to simulate a request that arrived over TLS when
// exercising HTTPS scheme detection.
var testTLSState = tls.ConnectionState{
	ServerName:        "example.com",
	HandshakeComplete: true,
	Version:           tls.VersionTLS13,
}
// createMinimalMiddleware returns a TraefikOidc carrying just enough
// configuration for the URL-helper tests in this file: a no-op logger, the
// provider's issuer/authorize/token endpoints, client credentials, the
// default openid/profile/email scopes, an empty exclusion set and PKCE
// switched off. Tests tweak individual fields after calling it.
func createMinimalMiddleware() *TraefikOidc {
	return &TraefikOidc{
		logger: newNoOpLogger(),

		// Provider endpoints.
		issuerURL: "https://provider.example.com",
		authURL:   "https://provider.example.com/authorize",
		tokenURL:  "https://provider.example.com/token",

		// Client registration.
		clientID:     "test-client",
		clientSecret: "test-secret",
		scopes:       []string{"openid", "profile", "email"},

		excludedURLs: make(map[string]struct{}),
		enablePKCE:   false,
	}
}
// TestDetermineScheme tests scheme determination edge cases.
//
// Cases are grouped by the forceHTTPS setting. When forceHTTPS is off, the
// expectations are: X-Forwarded-Proto wins, then the request's TLS state,
// otherwise plain http. When forceHTTPS is on, every request must resolve to
// https. Deployments behind a TLS-terminating load balancer rely on this.
func TestDetermineScheme(t *testing.T) {
	const (
		plainTarget = "http://example.com/auth"
		tlsTarget   = "https://example.com/auth"
	)

	// schemeCase describes one request shape and the scheme it must resolve to.
	type schemeCase struct {
		name    string
		target  string // URL handed to httptest.NewRequest
		proto   string // X-Forwarded-Proto value; "" leaves the header unset
		withTLS bool   // point req.TLS at testTLSState
		want    string
		msg     string // optional assertion message
	}

	groups := []struct {
		name       string
		forceHTTPS bool
		cases      []schemeCase
	}{
		{
			name:       "forceHTTPS=false: backward compatibility",
			forceHTTPS: false,
			cases: []schemeCase{
				{name: "defaults to http when no headers or TLS", target: plainTarget, want: "http"},
				{name: "uses X-Forwarded-Proto when present", target: plainTarget, proto: "https", want: "https"},
				{name: "X-Forwarded-Proto takes precedence over TLS", target: tlsTarget, proto: "http", withTLS: true, want: "http"},
				{name: "uses TLS when present and no X-Forwarded-Proto", target: tlsTarget, withTLS: true, want: "https"},
			},
		},
		{
			name:       "forceHTTPS=true: overrides all detection",
			forceHTTPS: true,
			cases: []schemeCase{
				{name: "returns https with no headers or TLS", target: plainTarget, want: "https",
					msg: "forceHTTPS should override default http"},
				{name: "returns https even with X-Forwarded-Proto: http", target: plainTarget, proto: "http", want: "https",
					msg: "forceHTTPS should override X-Forwarded-Proto"},
				{name: "returns https with X-Forwarded-Proto: https", target: plainTarget, proto: "https", want: "https"},
				{name: "returns https with TLS connection", target: tlsTarget, withTLS: true, want: "https"},
				{name: "returns https even when all indicators suggest http", target: plainTarget, proto: "http", want: "https",
					msg: "forceHTTPS should be absolute override"},
			},
		},
		{
			name:       "AWS ALB scenario: TLS termination at load balancer",
			forceHTTPS: true,
			cases: []schemeCase{
				// This simulates the issue from GitHub #82:
				// - Client connects via HTTPS to ALB
				// - ALB terminates TLS and forwards HTTP to Traefik
				// - Traefik overwrites X-Forwarded-Proto based on its view (HTTP)
				// - Plugin receives X-Forwarded-Proto: http (incorrect), no TLS
				{name: "simulates ALB overwriting X-Forwarded-Proto to http", target: plainTarget, proto: "http", want: "https",
					msg: "forceHTTPS should ensure HTTPS redirect_uri despite incorrect header"},
				// Some configurations may not set the header at all.
				{name: "simulates missing X-Forwarded-Proto header", target: plainTarget, want: "https",
					msg: "forceHTTPS should ensure HTTPS even without headers"},
			},
		},
	}

	for _, group := range groups {
		t.Run(group.name, func(t *testing.T) {
			middleware := createMinimalMiddleware()
			middleware.forceHTTPS = group.forceHTTPS

			for _, tc := range group.cases {
				t.Run(tc.name, func(t *testing.T) {
					req := httptest.NewRequest("GET", tc.target, nil)
					if tc.proto != "" {
						req.Header.Set("X-Forwarded-Proto", tc.proto)
					}
					if tc.withTLS {
						req.TLS = &testTLSState
					}

					got := utils.DetermineScheme(req, middleware.forceHTTPS)
					assert.Equal(t, tc.want, got, tc.msg)
				})
			}
		})
	}
}
// TestBuildURLWithParamsErrorPaths tests error handling in buildURLWithParams.
//
// Every subtest builds its own middleware. The first case deliberately
// corrupts issuerURL. With a shared fixture, that corruption would leak into
// later cases, making their outcome depend on subtest order and on which
// cases a -run filter selects.
func TestBuildURLWithParamsErrorPaths(t *testing.T) {
	t.Run("invalid issuer URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.issuerURL = "://invalid"
		params := url.Values{}
		params.Set("test", "value")
		result := middleware.buildURLWithParams("/path", params)
		assert.Empty(t, result)
	})
	t.Run("invalid relative URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("://invalid-relative", params)
		assert.Empty(t, result)
	})
	t.Run("invalid absolute URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("http://[invalid-url", params)
		assert.Empty(t, result)
	})
	t.Run("dangerous host in absolute URL returns empty string", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		result := middleware.buildURLWithParams("https://localhost/callback", params)
		assert.Empty(t, result)
	})
	t.Run("successful relative URL resolution", func(t *testing.T) {
		// The fresh fixture's issuerURL is https://provider.example.com, which
		// relative paths resolve against.
		middleware := createMinimalMiddleware()
		params := url.Values{}
		params.Set("key", "value")
		result := middleware.buildURLWithParams("/oauth/authorize", params)
		assert.NotEmpty(t, result)
		assert.Contains(t, result, "https://provider.example.com/oauth/authorize")
		assert.Contains(t, result, "key=value")
	})
	t.Run("successful absolute URL", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		params := url.Values{}
		params.Set("client_id", "test")
		result := middleware.buildURLWithParams("https://api.example.com/endpoint", params)
		assert.NotEmpty(t, result)
		assert.Contains(t, result, "https://api.example.com/endpoint")
		assert.Contains(t, result, "client_id=test")
	})
}
// TestValidateParsedURLCases tests URL validation edge cases.
//
// err.Error() is only read after assert.Error has confirmed a non-nil error.
// A plain assert.Error does not stop the test, so an unexpected nil error
// would otherwise panic on err.Error() and hide which case regressed. Fixture
// URLs go through parse, which fails fast on a malformed fixture instead of
// handing a nil *url.URL to the validator.
func TestValidateParsedURLCases(t *testing.T) {
	middleware := createMinimalMiddleware()

	// parse parses a fixture URL, aborting the current test if it is malformed.
	parse := func(t *testing.T, raw string) *url.URL {
		t.Helper()
		u, err := url.Parse(raw)
		if err != nil {
			t.Fatalf("fixture URL %q does not parse: %v", raw, err)
		}
		return u
	}

	t.Run("disallowed schemes rejected", func(t *testing.T) {
		invalidSchemes := []string{
			"ftp://example.com",
			"file:///etc/passwd",
			"javascript:alert(1)",
			"data:text/html,test",
		}
		for _, urlStr := range invalidSchemes {
			err := middleware.validateParsedURL(parse(t, urlStr))
			if assert.Error(t, err, "should reject scheme: %s", urlStr) {
				assert.Contains(t, err.Error(), "disallowed URL scheme")
			}
		}
	})
	t.Run("http scheme allowed with warning", func(t *testing.T) {
		err := middleware.validateParsedURL(parse(t, "http://example.com/path"))
		assert.NoError(t, err)
	})
	t.Run("missing host rejected", func(t *testing.T) {
		u := &url.URL{
			Scheme: "https",
			Host:   "",
			Path:   "/path",
		}
		err := middleware.validateParsedURL(u)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "missing host")
		}
	})
	t.Run("path traversal rejected", func(t *testing.T) {
		// url.Parse keeps the raw "/../../" path segments, so the validator
		// sees the traversal attempt.
		err := middleware.validateParsedURL(parse(t, "https://example.com/../../etc/passwd"))
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "path traversal")
		}
	})
	t.Run("valid URLs accepted", func(t *testing.T) {
		validURLs := []string{
			"https://example.com",
			"https://example.com/path",
			"https://sub.example.com:8080/path?query=value",
		}
		for _, urlStr := range validURLs {
			err := middleware.validateParsedURL(parse(t, urlStr))
			assert.NoError(t, err, "should accept: %s", urlStr)
		}
	})
}
// TestValidateHostComprehensive tests comprehensive host validation.
//
// Each group lists hosts that validateHost must reject and/or accept.
// - Rejected: loopback, private, link-local, unspecified and multicast
//   addresses; well-known internal hostnames; malformed input.
// - Accepted: public IPs and ordinary DNS names, with or without a port.
func TestValidateHostComprehensive(t *testing.T) {
	middleware := createMinimalMiddleware()

	// msgFor builds the optional testify message arguments for one host. An
	// empty format means the assertion carries no custom message.
	msgFor := func(format, host string) []interface{} {
		if format == "" {
			return nil
		}
		return []interface{}{format, host}
	}

	groups := []struct {
		name   string
		reject []string
		accept []string
		format string
	}{
		{
			name:   "loopback IPs rejected",
			reject: []string{"127.0.0.1", "127.255.255.255", "::1"},
			format: "should reject loopback: %s",
		},
		{
			name:   "private IPs rejected",
			reject: []string{"10.0.0.1", "172.16.0.1", "192.168.1.1", "fd00::1"},
			format: "should reject private IP: %s",
		},
		{
			name:   "link-local IPs rejected",
			reject: []string{"169.254.1.1", "fe80::1"},
			format: "should reject link-local: %s",
		},
		{
			name:   "unspecified and multicast rejected",
			reject: []string{"0.0.0.0", "::", "224.0.0.1", "ff02::1"},
			format: "should reject special IP: %s",
		},
		{
			name:   "dangerous hostnames rejected",
			reject: []string{"localhost", "LOCALHOST", "169.254.169.254", "metadata.google.internal"},
			format: "should reject: %s",
		},
		{
			name:   "invalid host format rejected",
			reject: []string{"[::1:invalid"},
			format: "should reject invalid format: %s",
		},
		{
			name:   "hosts with ports",
			reject: []string{"localhost:8080", "192.168.1.1:443"},
			accept: []string{"example.com:443"},
		},
		{
			name:   "valid public IPs accepted",
			accept: []string{"8.8.8.8", "1.1.1.1", "93.184.216.34"},
			format: "should accept public IP: %s",
		},
		{
			name:   "valid hostnames accepted",
			accept: []string{"example.com", "sub.example.com", "api.service.example.com:443"},
			format: "should accept: %s",
		},
	}

	for _, group := range groups {
		t.Run(group.name, func(t *testing.T) {
			for _, host := range group.reject {
				assert.Error(t, middleware.validateHost(host), msgFor(group.format, host)...)
			}
			for _, host := range group.accept {
				assert.NoError(t, middleware.validateHost(host), msgFor(group.format, host)...)
			}
		})
	}
}
// TestValidateURLEdgeCasesComprehensive tests the validateURL wrapper.
//
// Negative cases use require.Error before inspecting err.Error(). If
// validateURL unexpectedly returned nil, a plain assert.Error would let
// execution continue into err.Error() and panic, rather than reporting a
// clean failure.
func TestValidateURLEdgeCasesComprehensive(t *testing.T) {
	middleware := createMinimalMiddleware()
	t.Run("empty URL rejected", func(t *testing.T) {
		err := middleware.validateURL("")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "empty URL")
	})
	t.Run("invalid URL format rejected", func(t *testing.T) {
		err := middleware.validateURL("ht tp://invalid url")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid URL format")
	})
	t.Run("valid URLs accepted", func(t *testing.T) {
		validURLs := []string{
			"https://example.com/path",
			"https://example.com/path?key=value",
		}
		for _, urlStr := range validURLs {
			err := middleware.validateURL(urlStr)
			assert.NoError(t, err, "should accept: %s", urlStr)
		}
	})
	t.Run("URL with dangerous host rejected", func(t *testing.T) {
		err := middleware.validateURL("https://localhost/path")
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalid host")
	})
}
// TestBuildAuthURLAudienceParameter tests audience parameter handling
func TestBuildAuthURLAudienceParameter(t *testing.T) {
t.Run("audience added when different from client_id", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.audience = "https://api.example.com"
authURL := middleware.buildAuthURL(
"https://app.com/callback",
"state123",
"nonce456",
"",
)
assert.Contains(t, authURL, "audience=")
})
t.Run("audience not added when empty", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.audience = ""
authURL := middleware.buildAuthURL(
"https://app.com/callback",
"state123",
"nonce456",
"",
)
assert.NotContains(t, authURL, "audience=")
})
t.Run("audience not added when equal to client_id", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.audience = middleware.clientID
authURL := middleware.buildAuthURL(
"https://app.com/callback",
"state123",
"nonce456",
"",
)
assert.NotContains(t, authURL, "audience=")
})
}
// TestBuildAuthURLPKCEParameters tests PKCE parameter handling.
//
// code_challenge and code_challenge_method=S256 should appear only when PKCE
// is enabled AND a non-empty challenge is supplied. Disabling PKCE or passing
// an empty challenge must leave code_challenge= out of the URL.
func TestBuildAuthURLPKCEParameters(t *testing.T) {
	cases := []struct {
		name       string
		enablePKCE bool
		challenge  string
		wantPKCE   bool
	}{
		{name: "PKCE parameters added when enabled with challenge", enablePKCE: true, challenge: "challenge789", wantPKCE: true},
		{name: "PKCE parameters not added when challenge empty", enablePKCE: true, challenge: ""},
		{name: "PKCE parameters not added when disabled", enablePKCE: false, challenge: "challenge789"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			middleware := createMinimalMiddleware()
			middleware.enablePKCE = tc.enablePKCE

			authURL := middleware.buildAuthURL("https://app.com/callback", "state123", "nonce456", tc.challenge)

			if !tc.wantPKCE {
				assert.NotContains(t, authURL, "code_challenge=")
				return
			}
			assert.Contains(t, authURL, "code_challenge=challenge789")
			assert.Contains(t, authURL, "code_challenge_method=S256")
		})
	}
}
// TestForceHTTPSIntegration tests the complete flow of building redirect URIs with forceHTTPS
func TestForceHTTPSIntegration(t *testing.T) {
t.Run("redirect_uri uses https when forceHTTPS=true", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.forceHTTPS = true
// Simulate AWS ALB scenario: HTTP request with incorrect X-Forwarded-Proto
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
req.Header.Set("X-Forwarded-Proto", "http") // Traefik overwrote it
req.Host = "service.example.com"
req.TLS = nil
// Build the full redirect URL as middleware does
scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
assert.Equal(t, "https", scheme, "scheme should be https due to forceHTTPS")
assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
"redirect_uri should use https scheme")
})
t.Run("buildAuthURL contains https redirect_uri with forceHTTPS", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.forceHTTPS = true
// Simulate building auth URL with HTTP redirect_uri
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
req.Header.Set("X-Forwarded-Proto", "http")
req.Host = "service.example.com"
req.TLS = nil
scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
authURL := middleware.buildAuthURL(redirectURL, "state123", "nonce456", "")
assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fservice.example.com%2Foauth2%2Fcallback",
"auth URL should contain HTTPS redirect_uri")
assert.NotContains(t, authURL, "redirect_uri=http%3A",
"auth URL should not contain HTTP redirect_uri")
})
t.Run("without forceHTTPS respects X-Forwarded-Proto", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.forceHTTPS = false
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Host = "service.example.com"
scheme := utils.DetermineScheme(req, middleware.forceHTTPS)
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
"should use https from X-Forwarded-Proto when forceHTTPS is false")
})
}
// TestBuildAuthURLExtraAuthParams verifies operator-configured extra
// authorization parameters are appended to the authorization URL, and that
// they can never override parameters the plugin itself manages.
func TestBuildAuthURLExtraAuthParams(t *testing.T) {
	const redirectURI = "https://app.com/callback"

	t.Run("extra params are added (e.g. screen_hint=signup)", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.extraAuthParams = map[string]string{
			"screen_hint": "signup",
			"ui_locales":  "en",
		}
		authURL := middleware.buildAuthURL(redirectURI, "state123", "nonce456", "")
		assert.Contains(t, authURL, "screen_hint=signup")
		assert.Contains(t, authURL, "ui_locales=en")
	})
	t.Run("nil/empty extraAuthParams is a no-op", func(t *testing.T) {
		// Cover both an unset (nil) map and an explicitly configured empty
		// map; the subtest name promises both are harmless.
		for _, extra := range []map[string]string{nil, {}} {
			middleware := createMinimalMiddleware()
			middleware.extraAuthParams = extra
			authURL := middleware.buildAuthURL(redirectURI, "state123", "nonce456", "")
			assert.Contains(t, authURL, "client_id=test-client", "extraAuthParams=%#v", extra)
			assert.NotContains(t, authURL, "screen_hint", "extraAuthParams=%#v", extra)
		}
	})
	t.Run("extra params CANNOT override plugin-managed params", func(t *testing.T) {
		middleware := createMinimalMiddleware()
		middleware.extraAuthParams = map[string]string{
			"client_id":     "ATTACKER",
			"state":         "ATTACKER",
			"redirect_uri":  "https://evil.example.com",
			"response_type": "token",
		}
		authURL := middleware.buildAuthURL(redirectURI, "state123", "nonce456", "")

		// Every plugin-managed value must survive unchanged...
		assert.Contains(t, authURL, "client_id=test-client")
		assert.Contains(t, authURL, "state=state123")
		assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fapp.com%2Fcallback")
		assert.Contains(t, authURL, "response_type=code")

		// ...and no injected value may appear, not even as a duplicate key.
		assert.NotContains(t, authURL, "ATTACKER")
		assert.NotContains(t, authURL, "evil.example.com")
		assert.NotContains(t, authURL, "response_type=token")
	})
}