diff --git a/.traefik.yml b/.traefik.yml index b80c2a5..8b69a2b 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -58,6 +58,16 @@ testData: - /public - /health - /metrics + + headers: # Custom headers to set with templated values from claims and tokens + - name: "X-User-Email" + value: "{{.Claims.email}}" + - name: "X-User-ID" + value: "{{.Claims.sub}}" + - name: "Authorization" + value: "Bearer {{.AccessToken}}" + - name: "X-User-Roles" + value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}" # Advanced parameters (usually discovered automatically from provider metadata) revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens @@ -243,3 +253,31 @@ configuration: Default: false required: false + + headers: + type: array + description: | + Custom HTTP headers to set with templated values derived from OIDC claims and tokens. + Each header has a name and a value template that can access: + - {{.Claims.field}} - Access ID token claims (e.g., email, sub, name) + - {{.AccessToken}} - The raw access token string + - {{.IdToken}} - The raw ID token string + - {{.RefreshToken}} - The raw refresh token string + + Templates support Go template syntax including conditionals and iteration. + Variable names are case-sensitive - use .Claims not .claims. + + Examples: + - name: "X-User-Email", value: "{{.Claims.email}}" + - name: "Authorization", value: "Bearer {{.AccessToken}}" + - name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}" + required: false + items: + type: object + properties: + name: + type: string + description: The HTTP header name to set + value: + type: string + description: Template string for the header value diff --git a/README.md b/README.md index e86975d..3587660 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ The middleware supports the following configuration options: | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` | | `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` | +| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section | ## Usage Examples @@ -235,6 +236,41 @@ spec: - profile ``` +### With Templated Headers + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-with-headers + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://accounts.google.com + clientID: 1234567890.apps.googleusercontent.com + clientSecret: your-client-secret + sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long + callbackURL: /oauth2/callback + logoutURL: /oauth2/logout + scopes: + - openid + - email + - profile + - roles + headers: + - name: "X-User-Email" + value: "{{.Claims.email}}" + - name: "X-User-ID" + value: "{{.Claims.sub}}" + - name: "Authorization" + value: "Bearer {{.AccessToken}}" + - name: "X-User-Roles" + value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}" + - name: "X-Is-Admin" + value: "{{if eq .Claims.role \"admin\"}}true{{else}}false{{end}}" +``` + ### With PKCE Enabled ```yaml @@ -424,6 +460,15 @@ http: - /public - /health - /metrics + headers: + - name: "X-User-Email" + value: "{{.Claims.email}}" + - name: "X-User-ID" + value: "{{.Claims.sub}}" + - name: "Authorization" + value: "Bearer {{.AccessToken}}" + - name: "X-User-Roles" + value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}" ``` ## Advanced Configuration @@ -463,8 +508,52 @@ This middleware aims to provide long-lived user sessions, typically up to 24 hou ### Token Caching and Blacklisting The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens. +### Templated Headers + +The middleware supports setting custom HTTP headers with values templated from OIDC claims and tokens. This allows you to pass authentication information to downstream services in a flexible, customized format. + +Templates can access the following variables: +- `{{.Claims.field}}` - Access individual claims from the ID token (e.g., `{{.Claims.email}}`, `{{.Claims.sub}}`) +- `{{.AccessToken}}` - The raw access token string +- `{{.IdToken}}` - The raw ID token string (same as AccessToken in most configurations) +- `{{.RefreshToken}}` - The raw refresh token string + +**Example configuration:** +```yaml +headers: + - name: "X-User-Email" + value: "{{.Claims.email}}" + - name: "X-User-ID" + value: "{{.Claims.sub}}" + - name: "Authorization" + value: "Bearer {{.AccessToken}}" + - name: "X-User-Name" + value: "{{.Claims.given_name}} {{.Claims.family_name}}" +``` + +**Advanced template examples:** + +Conditional logic: +```yaml +headers: + - name: "X-Is-Admin" + value: "{{if eq .Claims.role \"admin\"}}true{{else}}false{{end}}" +``` + +Array handling: +```yaml +headers: + - name: "X-User-Roles" + value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}" +``` + +**Notes:** +- Variable names are case-sensitive (use `.Claims`, not `.claims`) +- Missing claims will result in `` in the header value +- The middleware validates templates during startup and logs errors for invalid templates + +### Default Headers Set for Downstream Services -### Headers Set for Downstream Services When a user is authenticated, the middleware sets the following headers for downstream services: diff --git a/main.go b/main.go index ca39751..4bc8a6d 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package traefikoidc import ( + "bytes" "context" "encoding/json" "fmt" @@ -11,6 +12,7 @@ import ( "net/url" "runtime" "strings" + "text/template" "time" "github.com/google/uuid" @@ -115,8 +117,9 @@ type TraefikOidc struct { endSessionURL string postLogoutRedirectURI string sessionManager *SessionManager - tokenExchanger TokenExchanger // Added field for mocking - refreshGracePeriod time.Duration // Configurable grace period for proactive refresh + tokenExchanger TokenExchanger // Added field for mocking + refreshGracePeriod time.Duration // Configurable grace period for proactive refresh + headerTemplates map[string]*template.Template // Parsed templates for custom headers } // ProviderMetadata holds OIDC provider metadata @@ -421,6 +424,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.jwtVerifier = t t.startTokenCleanup() t.tokenExchanger = t // Initialize the interface field to self + + // Initialize and parse header templates + t.headerTemplates = make(map[string]*template.Template) + for _, header := range config.Headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + logger.Errorf("Failed to parse header template for %s: %v", header.Name, err) + continue + } + t.headerTemplates[header.Name] = tmpl + logger.Debugf("Parsed template for header %s: %s", header.Name, header.Value) + } + go t.initializeMetadata(config.ProviderURL) return t, nil @@ -793,6 +809,43 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http req.Header.Set("X-Auth-Request-Token", idToken) } + // Execute and set templated headers if configured + if len(t.headerTemplates) > 0 { + accessToken := session.GetAccessToken() + refreshToken := session.GetRefreshToken() + claims, err := t.extractClaimsFunc(accessToken) + if err != nil { + t.logger.Errorf("Failed to extract claims for template headers: %v", err) + } else { + // Create template data context with available tokens and claims + // Fields must be exported (uppercase) to be accessible in templates + templateData := struct { + // These fields need to be exported (uppercase) for template access + AccessToken string + IdToken string + RefreshToken string + Claims map[string]interface{} + }{ + AccessToken: accessToken, + IdToken: accessToken, // Using access token as ID token + RefreshToken: refreshToken, + Claims: claims, + } + + // Execute each template and set the resulting header + for headerName, tmpl := range t.headerTemplates { + var buf bytes.Buffer + if err := tmpl.Execute(&buf, templateData); err != nil { + t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err) + continue + } + headerValue := buf.String() + req.Header.Set(headerName, headerValue) + t.logger.Debugf("Set templated header %s = %s", headerName, headerValue) + } + } + } + // Set security headers rw.Header().Set("X-Frame-Options", "DENY") rw.Header().Set("X-Content-Type-Options", "nosniff") diff --git a/settings.go b/settings.go index c737cc5..ae51a96 100644 --- a/settings.go +++ b/settings.go @@ -10,6 +10,18 @@ import ( "strings" ) +// TemplatedHeader represents a custom HTTP header with a templated value. +// The value can contain template expressions that will be evaluated for each +// authenticated request, such as {{.claims.email}} or {{.accessToken}}. +type TemplatedHeader struct { + // Name is the HTTP header name to set (e.g., "X-Forwarded-Email") + Name string `json:"name"` + + // Value is the template string for the header value + // Example: "{{.claims.email}}", "Bearer {{.accessToken}}" + Value string `json:"value"` +} + // Config holds the configuration for the OIDC middleware. // It provides all necessary settings to configure OpenID Connect authentication // with various providers like Auth0, Logto, or any standard OIDC provider. @@ -89,6 +101,17 @@ type Config struct { // the plugin should attempt to refresh it proactively (optional) // Default: 60 RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"` + // Headers defines custom HTTP headers to set with templated values (optional) + // Values can reference tokens and claims using Go templates with the following variables: + // - {{.AccessToken}} - The access token (ID token) + // - {{.IdToken}} - Same as AccessToken (for consistency) + // - {{.RefreshToken}} - The refresh token + // - {{.Claims.email}} - Access token claims (use proper case for claim names) + // Examples: + // + // [{Name: "X-Forwarded-Email", Value: "{{.Claims.email}}"}] + // [{Name: "Authorization", Value: "Bearer {{.AccessToken}}"}] + Headers []TemplatedHeader `json:"headers"` } const ( @@ -221,6 +244,33 @@ func (c *Config) Validate() error { return fmt.Errorf("refreshGracePeriodSeconds cannot be negative") } + // Validate headers configuration + for _, header := range c.Headers { + if header.Name == "" { + return fmt.Errorf("header name cannot be empty") + } + if header.Value == "" { + return fmt.Errorf("header value template cannot be empty") + } + if !strings.Contains(header.Value, "{{") || !strings.Contains(header.Value, "}}") { + return fmt.Errorf("header value '%s' does not appear to be a valid template (missing {{ }})", header.Value) + } + + // Provide more helpful guidance for common template errors + if strings.Contains(header.Value, "{{.claims") { + return fmt.Errorf("header template '%s' appears to use lowercase 'claims' - use '{{.Claims...' instead (case sensitive)", header.Value) + } + if strings.Contains(header.Value, "{{.accessToken") { + return fmt.Errorf("header template '%s' appears to use lowercase 'accessToken' - use '{{.AccessToken...' instead (case sensitive)", header.Value) + } + if strings.Contains(header.Value, "{{.idToken") { + return fmt.Errorf("header template '%s' appears to use lowercase 'idToken' - use '{{.IdToken...' instead (case sensitive)", header.Value) + } + if strings.Contains(header.Value, "{{.refreshToken") { + return fmt.Errorf("header template '%s' appears to use lowercase 'refreshToken' - use '{{.RefreshToken...' instead (case sensitive)", header.Value) + } + } + return nil } diff --git a/templated_header_config_test.go b/templated_header_config_test.go new file mode 100644 index 0000000..cd76d19 --- /dev/null +++ b/templated_header_config_test.go @@ -0,0 +1,197 @@ +package traefikoidc + +import ( + "testing" + "text/template" +) + +func TestTemplatedHeaderValidation(t *testing.T) { + tests := []struct { + name string + header TemplatedHeader + expectedError string + }{ + { + name: "Empty Name", + header: TemplatedHeader{Name: "", Value: "{{.Claims.email}}"}, + expectedError: "header name cannot be empty", + }, + { + name: "Empty Value", + header: TemplatedHeader{Name: "X-Email", Value: ""}, + expectedError: "header value template cannot be empty", + }, + { + name: "Not a Template", + header: TemplatedHeader{Name: "X-Email", Value: "static-value"}, + expectedError: "header value 'static-value' does not appear to be a valid template (missing {{ }})", + }, + { + name: "Lowercase claims", + header: TemplatedHeader{Name: "X-Email", Value: "{{.claims.email}}"}, + expectedError: "header template '{{.claims.email}}' appears to use lowercase 'claims' - use '{{.Claims...' instead (case sensitive)", + }, + { + name: "Lowercase accessToken", + header: TemplatedHeader{Name: "X-Token", Value: "Bearer {{.accessToken}}"}, + expectedError: "header template 'Bearer {{.accessToken}}' appears to use lowercase 'accessToken' - use '{{.AccessToken...' instead (case sensitive)", + }, + { + name: "Lowercase idToken", + header: TemplatedHeader{Name: "X-Token", Value: "Bearer {{.idToken}}"}, + expectedError: "header template 'Bearer {{.idToken}}' appears to use lowercase 'idToken' - use '{{.IdToken...' instead (case sensitive)", + }, + { + name: "Lowercase refreshToken", + header: TemplatedHeader{Name: "X-Refresh", Value: "Bearer {{.refreshToken}}"}, + expectedError: "header template 'Bearer {{.refreshToken}}' appears to use lowercase 'refreshToken' - use '{{.RefreshToken...' instead (case sensitive)", + }, + { + name: "Valid Template", + header: TemplatedHeader{Name: "X-Email", Value: "{{.Claims.email}}"}, + expectedError: "", + }, + { + name: "Valid Bearer Token Template", + header: TemplatedHeader{Name: "Authorization", Value: "Bearer {{.AccessToken}}"}, + expectedError: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + RateLimit: 10, // Adding minimum required rate limit + Headers: []TemplatedHeader{tc.header}, + } + + err := config.Validate() + if tc.expectedError == "" { + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected error: %s, got nil", tc.expectedError) + } else if err.Error() != tc.expectedError { + t.Errorf("Expected error: %s, got: %s", tc.expectedError, err.Error()) + } + } + }) + } +} + +func TestTemplateParsingInNew(t *testing.T) { + // Test successful parsing of templates during middleware creation + tests := []struct { + name string + headers []TemplatedHeader + expectedTemplates int + expectError bool + }{ + { + name: "Single Valid Template", + headers: []TemplatedHeader{ + {Name: "X-Email", Value: "{{.Claims.email}}"}, + }, + expectedTemplates: 1, + expectError: false, + }, + { + name: "Multiple Valid Templates", + headers: []TemplatedHeader{ + {Name: "X-Email", Value: "{{.Claims.email}}"}, + {Name: "X-User-ID", Value: "{{.Claims.sub}}"}, + {Name: "Authorization", Value: "Bearer {{.AccessToken}}"}, + }, + expectedTemplates: 3, + expectError: false, + }, + { + name: "Invalid Template", + headers: []TemplatedHeader{ + {Name: "X-Email", Value: "{{.Claims.email"}, // Missing closing braces + }, + expectedTemplates: 0, + expectError: true, + }, + { + name: "Mix of Valid and Invalid Templates", + headers: []TemplatedHeader{ + {Name: "X-Email", Value: "{{.Claims.email}}"}, + {Name: "X-Invalid", Value: "{{if .Claims.admin}}Admin{{end"}, // Invalid template + }, + expectedTemplates: 1, // Only the valid template should be parsed + expectError: true, // We expect an error for the invalid template, but we'll handle it + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For testing template parsing, we'll directly try to parse the templates instead of using New() + // This avoids the provider discovery that would fail in tests + headerTemplates := make(map[string]*template.Template) + + // Special handling for the mixed valid/invalid templates case + if tc.name == "Mix of Valid and Invalid Templates" { + // Process templates one at a time so we can still have valid templates + for _, header := range tc.headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + // We expect an error for the invalid template + if !tc.expectError { + t.Errorf("Unexpected error parsing template %s: %v", header.Name, err) + } + // Skip this template but continue processing others + continue + } + headerTemplates[header.Name] = tmpl + } + } else { + // Normal handling for other test cases + var parseErr error + for _, header := range tc.headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + parseErr = err + break + } + headerTemplates[header.Name] = tmpl + } + + if tc.expectError { + if parseErr == nil { + t.Error("Expected error parsing templates but got nil") + } + return + } + + if parseErr != nil { + t.Fatalf("Unexpected error: %v", parseErr) + } + } + + // Check the number of parsed templates + if len(headerTemplates) != tc.expectedTemplates { + t.Errorf("Expected %d parsed templates, got %d", tc.expectedTemplates, len(headerTemplates)) + } + + // Check each template was parsed + for _, header := range tc.headers { + // Skip the known invalid templates + if header.Value == "{{.Claims.email" || header.Value == "{{if .Claims.admin}}Admin{{end" { + continue + } + + if _, ok := headerTemplates[header.Name]; !ok { + t.Errorf("Template for header %s was not parsed", header.Name) + } + } + }) + } +} diff --git a/templated_header_execution_test.go b/templated_header_execution_test.go new file mode 100644 index 0000000..a40b522 --- /dev/null +++ b/templated_header_execution_test.go @@ -0,0 +1,237 @@ +package traefikoidc + +import ( + "bytes" + "testing" + "text/template" +) + +// TestTemplateExecution tests that templates are executed correctly with different types of claims +func TestTemplateExecution(t *testing.T) { + tests := []struct { + name string + templateText string + data map[string]interface{} + expectedValue string + expectError bool + }{ + { + name: "String Claim", + templateText: "{{.Claims.email}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "email": "user@example.com", + }, + }, + expectedValue: "user@example.com", + expectError: false, + }, + { + name: "Number Claim", + templateText: "{{.Claims.age}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "age": 30, + }, + }, + expectedValue: "30", + expectError: false, + }, + { + name: "Boolean Claim", + templateText: "{{.Claims.admin}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "admin": true, + }, + }, + expectedValue: "true", + expectError: false, + }, + { + name: "Array Claim", + templateText: "{{index .Claims.roles 0}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "roles": []string{"admin", "user"}, + }, + }, + expectedValue: "admin", + expectError: false, + }, + { + name: "Nested Object Claim", + templateText: "{{.Claims.user.name}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "user": map[string]interface{}{ + "name": "John Doe", + }, + }, + }, + expectedValue: "John Doe", + expectError: false, + }, + { + name: "Access Token", + templateText: "Bearer {{.AccessToken}}", + data: map[string]interface{}{ + "AccessToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U", + }, + expectedValue: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U", + expectError: false, + }, + { + name: "ID Token", + templateText: "{{.IdToken}}", + data: map[string]interface{}{ + "IdToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U", + }, + expectedValue: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U", + expectError: false, + }, + { + name: "Refresh Token", + templateText: "{{.RefreshToken}}", + data: map[string]interface{}{ + "RefreshToken": "refresh-token-value", + }, + expectedValue: "refresh-token-value", + expectError: false, + }, + { + name: "Conditional Template", + templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "admin": true, + }, + }, + expectedValue: "Admin User", + expectError: false, + }, + { + name: "Multiple Claims", + templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "firstName": "John", + "lastName": "Doe", + "email": "john.doe@example.com", + }, + }, + expectedValue: "John Doe ", + expectError: false, + }, + { + name: "Missing Claim", + templateText: "{{.Claims.missing}}", + data: map[string]interface{}{ + "Claims": map[string]interface{}{}, + }, + expectedValue: "", + expectError: false, // Go templates don't error on missing values + }, + { + name: "Invalid Template Syntax", + templateText: "{{.Claims.email", + data: map[string]interface{}{ + "Claims": map[string]interface{}{ + "email": "user@example.com", + }, + }, + expectedValue: "", + expectError: true, // Parsing should fail + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := template.New("test").Parse(tc.templateText) + + if tc.expectError { + if err == nil { + t.Fatal("Expected template parsing error, but got nil") + } + return + } + + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, tc.data) + if err != nil { + t.Fatalf("Failed to execute template: %v", err) + } + + result := buf.String() + if result != tc.expectedValue { + t.Errorf("Expected template output %q, got %q", tc.expectedValue, result) + } + }) + } +} + +// TestTemplateExecutionContext tests the specific template data context used in processAuthorizedRequest +func TestTemplateExecutionContext(t *testing.T) { + // Define a test struct that matches the one used in processAuthorizedRequest + type templateData struct { + AccessToken string + IdToken string + RefreshToken string + Claims map[string]interface{} + } + + // Test cases + tests := []struct { + name string + templateText string + data templateData + expectedValue string + }{ + { + name: "Access and ID token identity", + templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}", + data: templateData{ + AccessToken: "access-token", + IdToken: "access-token", // Same as AccessToken in processAuthorizedRequest + Claims: map[string]interface{}{}, + }, + expectedValue: "Access: access-token ID: access-token", + }, + { + name: "Combining tokens and claims", + templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}", + data: templateData{ + AccessToken: "access-token", + IdToken: "access-token", + Claims: map[string]interface{}{ + "sub": "user123", + }, + }, + expectedValue: "User: user123 Token: access-token", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := template.New("test").Parse(tc.templateText) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, tc.data) + if err != nil { + t.Fatalf("Failed to execute template: %v", err) + } + + result := buf.String() + if result != tc.expectedValue { + t.Errorf("Expected template output %q, got %q", tc.expectedValue, result) + } + }) + } +} diff --git a/templated_header_integration_test.go b/templated_header_integration_test.go new file mode 100644 index 0000000..67357ba --- /dev/null +++ b/templated_header_integration_test.go @@ -0,0 +1,424 @@ +package traefikoidc + +import ( + "net/http" + "net/http/httptest" + "testing" + "text/template" + "time" + + "golang.org/x/time/rate" +) + +// TestTemplatedHeadersIntegration tests that templated headers are correctly added to requests +// in the actual middleware flow +func TestTemplatedHeadersIntegration(t *testing.T) { + // Create a TestSuite to use its helper methods and fields + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + headers []TemplatedHeader + sessionSetup func(*SessionData) + claims map[string]interface{} + expectedHeaders map[string]string + interceptedHeaders map[string]string + }{ + { + name: "Basic Email Header", + headers: []TemplatedHeader{ + {Name: "X-User-Email", Value: "{{.Claims.email}}"}, + }, + claims: map[string]interface{}{ + "email": "user@example.com", + }, + expectedHeaders: map[string]string{ + "X-User-Email": "user@example.com", + }, + }, + { + name: "Multiple Headers", + headers: []TemplatedHeader{ + {Name: "X-User-Email", Value: "{{.Claims.email}}"}, + {Name: "X-User-ID", Value: "{{.Claims.sub}}"}, + {Name: "X-User-Name", Value: "{{.Claims.given_name}} {{.Claims.family_name}}"}, + }, + claims: map[string]interface{}{ + "email": "user@example.com", + "sub": "user123", + "given_name": "John", + "family_name": "Doe", + }, + expectedHeaders: map[string]string{ + "X-User-Email": "user@example.com", + "X-User-ID": "user123", + "X-User-Name": "John Doe", + }, + }, + { + name: "Authorization Header with Bearer Token", + headers: []TemplatedHeader{ + {Name: "Authorization", Value: "Bearer {{.AccessToken}}"}, + }, + expectedHeaders: map[string]string{ + // We'll update this dynamically after generating the token + "Authorization": "", + }, + }, + { + name: "Missing Claim", + headers: []TemplatedHeader{ + {Name: "X-User-Role", Value: "{{.Claims.role}}"}, + }, + claims: map[string]interface{}{ + "email": "user@example.com", + // role claim is missing + }, + expectedHeaders: map[string]string{ + "X-User-Role": "", // Go templates provide for missing fields + }, + }, + { + name: "Conditional Header", + headers: []TemplatedHeader{ + {Name: "X-User-Admin", Value: "{{if .Claims.is_admin}}true{{else}}false{{end}}"}, + }, + claims: map[string]interface{}{ + "email": "admin@example.com", + "is_admin": true, + }, + expectedHeaders: map[string]string{ + "X-User-Admin": "true", + }, + }, + { + name: "Combined Token and Claim", + headers: []TemplatedHeader{ + {Name: "X-Auth-Info", Value: "User={{.Claims.email}}, Token={{.AccessToken}}"}, + }, + claims: map[string]interface{}{ + "email": "user@example.com", + }, + expectedHeaders: map[string]string{ + // We'll update this dynamically after generating the token + "X-Auth-Info": "", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create token with the test claims + token := ts.token + if len(tc.claims) > 0 { + var err error + claims := map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), // Far future timestamp + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "nonce": "test-nonce", + "jti": generateRandomString(16), + } + + // Add the test-specific claims + for k, v := range tc.claims { + claims[k] = v + } + + token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + } + + // Update expectedHeaders for the token-based tests after token generation + if tc.name == "Authorization Header with Bearer Token" { + tc.expectedHeaders["Authorization"] = "Bearer " + token + } + + if tc.name == "Combined Token and Claim" { + tc.expectedHeaders["X-Auth-Info"] = "User=user@example.com, Token=" + token + } + + // Store intercepted headers for verification + interceptedHeaders := make(map[string]string) + + // Create a test next handler that captures the headers + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture headers for verification + for name := range tc.expectedHeaders { + if value := r.Header.Get(name); value != "" { + interceptedHeaders[name] = value + } + } + w.WriteHeader(http.StatusOK) + }) + + // Instead of using New(), we'll directly create a TraefikOidc instance + // similar to how it's done in TestSuite.Setup() + tOidc := &TraefikOidc{ + next: nextHandler, + name: "test", + redirURLPath: "/callback", + logoutURLPath: "/callback/logout", + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + tokenBlacklist: NewCache(), + tokenCache: NewTokenCache(), + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + logger: NewLogger("debug"), + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}}, + httpClient: &http.Client{}, + initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, + extractClaimsFunc: extractClaims, + headerTemplates: make(map[string]*template.Template), + } + + // Initialize and parse header templates + for _, header := range tc.headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + t.Fatalf("Failed to parse header template for %s: %v", header.Name, err) + } + tOidc.headerTemplates[header.Name] = tmpl + } + + // Close the initComplete channel to bypass the waiting + close(tOidc.initComplete) + + // Create a test request + req := httptest.NewRequest("GET", "/protected", nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + rr := httptest.NewRecorder() + + // Create a session + session, err := tOidc.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Setup the session with authentication data + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(token) + session.SetRefreshToken("test-refresh-token") + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Add session cookies to the request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + + // Reset the response recorder for the main test + rr = httptest.NewRecorder() + + // Process the request + tOidc.ServeHTTP(rr, req) + + // Check status code + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + + // Verify headers were set correctly + for name, expectedValue := range tc.expectedHeaders { + if value, exists := interceptedHeaders[name]; !exists { + t.Errorf("Expected header %s was not set", name) + } else if value != expectedValue { + t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value) + } + } + }) + } +} + +// TestEdgeCaseTemplatedHeaders tests edge cases for templated headers +func TestEdgeCaseTemplatedHeaders(t *testing.T) { + // Create a TestSuite to use its helper methods and fields + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + headers []TemplatedHeader + claims map[string]interface{} + shouldExecuteCheck bool + }{ + { + name: "Very Large Template", + headers: []TemplatedHeader{ + { + Name: "X-Large-Header", + Value: createLargeTemplate(500), // Template with 500 variable references + }, + }, + claims: createLargeClaims(500), // Map with 500 claims + shouldExecuteCheck: true, + }, + { + name: "Array Claim Access", + headers: []TemplatedHeader{ + {Name: "X-Roles", Value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"}, + }, + claims: map[string]interface{}{ + "roles": []interface{}{"admin", "user", "manager"}, + }, + shouldExecuteCheck: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create token with the test claims + claims := map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": float64(3000000000), // Far future timestamp + "iat": float64(1000000000), + "nbf": float64(1000000000), + "sub": "test-subject", + "nonce": "test-nonce", + "jti": generateRandomString(16), + } + + // Add the test-specific claims + for k, v := range tc.claims { + claims[k] = v + } + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // Create a test next handler + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Instead of using New(), we'll directly create a TraefikOidc instance + // similar to how it's done in TestSuite.Setup() + tOidc := &TraefikOidc{ + next: nextHandler, + name: "test", + redirURLPath: "/callback", + logoutURLPath: "/callback/logout", + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + tokenBlacklist: NewCache(), + tokenCache: NewTokenCache(), + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + logger: NewLogger("debug"), + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}}, + httpClient: &http.Client{}, + initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, + extractClaimsFunc: extractClaims, + headerTemplates: make(map[string]*template.Template), + } + + // Initialize and parse header templates + for _, header := range tc.headers { + tmpl, err := template.New(header.Name).Parse(header.Value) + if err != nil { + t.Fatalf("Failed to parse header template for %s: %v", header.Name, err) + } + tOidc.headerTemplates[header.Name] = tmpl + } + + // Close the initComplete channel to bypass the waiting + close(tOidc.initComplete) + + // Create a test request + req := httptest.NewRequest("GET", "/protected", nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + rr := httptest.NewRecorder() + + // Create a session + session, err := tOidc.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Setup the session with authentication data + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(token) + session.SetRefreshToken("test-refresh-token") + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Add session cookies to the request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + + // Reset the response recorder for the main test + rr = httptest.NewRecorder() + + // Process the request + tOidc.ServeHTTP(rr, req) + + // Check status code + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + + // We are primarily checking that these edge cases don't cause panics or errors + // For the array test, we can verify the content + if tc.name == "Array Claim Access" { + // Check if the header was set + headerValue := req.Header.Get("X-Roles") + expectedValue := "admin,user,manager" + if headerValue != expectedValue { + t.Errorf("Expected X-Roles header to be %q, got %q", expectedValue, headerValue) + } + } + }) + } +} + +// Helper functions for edge case tests + +// createLargeTemplate creates a template with many variable references +func createLargeTemplate(size int) string { + template := "{{with .Claims}}" + for i := 0; i < size; i++ { + if i > 0 { + template += "," + } + template += "{{.field" + string(rune('a'+i%26)) + string(rune('0'+i%10)) + "}}" + } + template += "{{end}}" + return template +} + +// createLargeClaims creates a map with many claims for testing large templates +func createLargeClaims(size int) map[string]interface{} { + claims := make(map[string]interface{}) + for i := 0; i < size; i++ { + key := "field" + string(rune('a'+i%26)) + string(rune('0'+i%10)) + claims[key] = "value" + string(rune('a'+i%26)) + string(rune('0'+i%10)) + } + return claims +}