Files
traefikoidc/internal/middleware/request_handler_test.go
T
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00

656 lines
16 KiB
Go

package middleware
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// MockLogger implements the Logger interface for testing
type MockLogger struct {
DebugCalls []string
DebugfCalls []string
ErrorCalls []string
ErrorfCalls []string
InfoCalls []string
InfofCalls []string
}
func (m *MockLogger) Debug(msg string) {
m.DebugCalls = append(m.DebugCalls, msg)
}
func (m *MockLogger) Debugf(format string, args ...interface{}) {
m.DebugfCalls = append(m.DebugfCalls, format)
}
func (m *MockLogger) Error(msg string) {
m.ErrorCalls = append(m.ErrorCalls, msg)
}
func (m *MockLogger) Errorf(format string, args ...interface{}) {
m.ErrorfCalls = append(m.ErrorfCalls, format)
}
func (m *MockLogger) Info(msg string) {
m.InfoCalls = append(m.InfoCalls, msg)
}
func (m *MockLogger) Infof(format string, args ...interface{}) {
m.InfofCalls = append(m.InfofCalls, format)
}
// TestNewRequestProcessor tests the constructor
func TestNewRequestProcessor(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
if processor == nil {
t.Error("Expected NewRequestProcessor to return non-nil processor")
return
}
if processor.logger != logger {
t.Error("Expected processor to use provided logger")
}
}
// TestBuildRequestContext tests request context building
func TestBuildRequestContext(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setupRequest func() (*http.Request, http.ResponseWriter)
redirectPath string
expectedURL string
expectedHost string
}{
{
name: "Basic HTTP request",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "http://example.com/callback",
expectedHost: "example.com",
},
{
name: "HTTPS request with TLS",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "https://secure.com/test", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/auth",
expectedURL: "https://secure.com/auth",
expectedHost: "secure.com",
},
{
name: "Request with X-Forwarded-Proto header",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Proto", "https")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "https://internal.com/callback",
expectedHost: "internal.com",
},
{
name: "Request with X-Forwarded-Host header",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Host", "public.com")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "http://public.com/callback",
expectedHost: "public.com",
},
{
name: "Request with both forwarded headers",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "public.com")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/auth",
expectedURL: "https://public.com/auth",
expectedHost: "public.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, rw := tt.setupRequest()
ctx := processor.BuildRequestContext(rw, req, tt.redirectPath)
if ctx == nil {
t.Error("Expected BuildRequestContext to return non-nil context")
return
}
if ctx.Writer != rw {
t.Error("Expected context writer to match provided writer")
}
if ctx.Request != req {
t.Error("Expected context request to match provided request")
}
if ctx.RedirectURL != tt.expectedURL {
t.Errorf("Expected redirect URL '%s', got '%s'", tt.expectedURL, ctx.RedirectURL)
}
if ctx.Host != tt.expectedHost {
t.Errorf("Expected host '%s', got '%s'", tt.expectedHost, ctx.Host)
}
})
}
}
// TestIsHealthCheckRequest tests health check detection
func TestIsHealthCheckRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
path string
expected bool
}{
{
name: "Health check path",
path: "/health",
expected: true,
},
{
name: "Health check subpath",
path: "/health/status",
expected: true,
},
{
name: "Health check with query params",
path: "/health?check=db",
expected: true,
},
{
name: "Not a health check",
path: "/api/users",
expected: false,
},
{
name: "Health-related path (matches prefix)",
path: "/healthiness",
expected: true, // HasPrefix behavior - this actually matches
},
{
name: "Root path",
path: "/",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+tt.path, nil)
result := processor.IsHealthCheckRequest(req)
if result != tt.expected {
t.Errorf("Expected IsHealthCheckRequest to return %v for path '%s', got %v", tt.expected, tt.path, result)
}
})
}
}
// TestIsEventStreamRequest tests event stream detection
func TestIsEventStreamRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
acceptHeader string
expected bool
}{
{
name: "Event stream accept header",
acceptHeader: "text/event-stream",
expected: true,
},
{
name: "Event stream with other types",
acceptHeader: "text/html, text/event-stream, application/json",
expected: true,
},
{
name: "JSON accept header",
acceptHeader: "application/json",
expected: false,
},
{
name: "HTML accept header",
acceptHeader: "text/html,application/xhtml+xml",
expected: false,
},
{
name: "Empty accept header",
acceptHeader: "",
expected: false,
},
{
name: "Similar but not event stream",
acceptHeader: "text/event-source",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
if tt.acceptHeader != "" {
req.Header.Set("Accept", tt.acceptHeader)
}
result := processor.IsEventStreamRequest(req)
if result != tt.expected {
t.Errorf("Expected IsEventStreamRequest to return %v for accept header '%s', got %v", tt.expected, tt.acceptHeader, result)
}
})
}
}
// TestIsAjaxRequest tests AJAX request detection
func TestIsAjaxRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setupHeader func(*http.Request)
expected bool
}{
{
name: "XMLHttpRequest header",
setupHeader: func(req *http.Request) {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
},
expected: true,
},
{
name: "JSON content type",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
},
expected: true,
},
{
name: "JSON content type with charset",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
},
expected: true,
},
{
name: "JSON accept header",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "application/json")
},
expected: true,
},
{
name: "JSON accept with other types",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "text/html, application/json, application/xml")
},
expected: true,
},
{
name: "Multiple AJAX indicators",
setupHeader: func(req *http.Request) {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
},
expected: true,
},
{
name: "Regular HTML request",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "text/html,application/xhtml+xml")
},
expected: false,
},
{
name: "Form submission",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
},
expected: false,
},
{
name: "No special headers",
setupHeader: func(req *http.Request) {},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "http://example.com/api", nil)
tt.setupHeader(req)
result := processor.IsAjaxRequest(req)
if result != tt.expected {
t.Errorf("Expected IsAjaxRequest to return %v, got %v", tt.expected, result)
}
})
}
}
// TestWaitForInitialization tests initialization waiting
func TestWaitForInitialization(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
t.Run("Initialization completes successfully", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
initComplete := make(chan struct{})
go func() {
time.Sleep(10 * time.Millisecond)
close(initComplete)
}()
err := processor.WaitForInitialization(req, initComplete)
if err != nil {
t.Errorf("Expected no error when initialization completes, got: %v", err)
}
})
t.Run("Request context cancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req = req.WithContext(ctx)
initComplete := make(chan struct{})
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
err := processor.WaitForInitialization(req, initComplete)
if err == nil {
t.Error("Expected error when request context is cancelled")
}
if !strings.Contains(err.Error(), "request cancelled") {
t.Errorf("Expected 'request cancelled' error, got: %v", err)
}
if len(logger.DebugCalls) == 0 {
t.Error("Expected debug log when request is cancelled")
}
})
t.Run("Initialization timeout", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping timeout test in short mode")
}
req := httptest.NewRequest("GET", "http://example.com/test", nil)
initComplete := make(chan struct{}) // Never closes
// Note: This test takes 30 seconds due to hardcoded timeout in implementation
start := time.Now()
err := processor.WaitForInitialization(req, initComplete)
duration := time.Since(start)
if err == nil {
t.Error("Expected timeout error")
}
if !strings.Contains(err.Error(), "timeout") {
t.Errorf("Expected timeout error, got: %v", err)
}
// The timeout should be around 30 seconds, allow some variance
if duration < 29*time.Second || duration > 31*time.Second {
t.Errorf("Expected timeout after ~30 seconds, but got %v", duration)
}
if len(logger.ErrorCalls) == 0 {
t.Error("Expected error log when timeout occurs")
}
})
}
// TestDetermineScheme tests scheme determination
func TestDetermineScheme(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setup func(*http.Request)
expected string
}{
{
name: "X-Forwarded-Proto HTTPS",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "https")
},
expected: "https",
},
{
name: "X-Forwarded-Proto HTTP",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "http")
},
expected: "http",
},
{
name: "TLS connection without header",
setup: func(req *http.Request) {
req.TLS = &tls.ConnectionState{}
},
expected: "https",
},
{
name: "No TLS, no header",
setup: func(req *http.Request) {
// No special setup
},
expected: "http",
},
{
name: "X-Forwarded-Proto takes precedence over TLS",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "http")
req.TLS = &tls.ConnectionState{}
},
expected: "http",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
tt.setup(req)
result := processor.determineScheme(req)
if result != tt.expected {
t.Errorf("Expected scheme '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestDetermineHost tests host determination
func TestDetermineHost(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setup func(*http.Request)
expected string
}{
{
name: "X-Forwarded-Host header present",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Host", "public.example.com")
},
expected: "public.example.com",
},
{
name: "No X-Forwarded-Host, use req.Host",
setup: func(req *http.Request) {
// No special setup, will use req.Host
},
expected: "example.com",
},
{
name: "Empty X-Forwarded-Host, fallback to req.Host",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Host", "")
},
expected: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
tt.setup(req)
result := processor.determineHost(req)
if result != tt.expected {
t.Errorf("Expected host '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestBuildFullURL tests URL building
func TestBuildFullURL(t *testing.T) {
tests := []struct {
name string
scheme string
host string
path string
expected string
}{
{
name: "Basic URL construction",
scheme: "https",
host: "example.com",
path: "/callback",
expected: "https://example.com/callback",
},
{
name: "Path without leading slash",
scheme: "http",
host: "test.com",
path: "auth",
expected: "http://test.com/auth",
},
{
name: "Absolute HTTP URL in path",
scheme: "https",
host: "example.com",
path: "http://other.com/callback",
expected: "http://other.com/callback",
},
{
name: "Absolute HTTPS URL in path",
scheme: "http",
host: "example.com",
path: "https://secure.com/auth",
expected: "https://secure.com/auth",
},
{
name: "Root path",
scheme: "https",
host: "example.com:8080",
path: "/",
expected: "https://example.com:8080/",
},
{
name: "Empty path",
scheme: "https",
host: "example.com",
path: "",
expected: "https://example.com/",
},
{
name: "Path with query parameters",
scheme: "https",
host: "example.com",
path: "/callback?state=abc123",
expected: "https://example.com/callback?state=abc123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildFullURL(tt.scheme, tt.host, tt.path)
if result != tt.expected {
t.Errorf("Expected URL '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestRequestContext tests the RequestContext struct
func TestRequestContext(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rw := httptest.NewRecorder()
ctx := &RequestContext{
Writer: rw,
Request: req,
RedirectURL: "https://example.com/callback",
Scheme: "https",
Host: "example.com",
}
if ctx.Writer != rw {
t.Error("Expected Writer to be set correctly")
}
if ctx.Request != req {
t.Error("Expected Request to be set correctly")
}
if ctx.RedirectURL != "https://example.com/callback" {
t.Error("Expected RedirectURL to be set correctly")
}
if ctx.Scheme != "https" {
t.Error("Expected Scheme to be set correctly")
}
if ctx.Host != "example.com" {
t.Error("Expected Host to be set correctly")
}
}