Compare commits

...

4 Commits

Author SHA1 Message Date
lukaszraczylo c474bbafd6 Cleanup [dec2025] (#101)
* Cleanup excessive comments.

* Remove leftovers hanging around from previous refactor

* Improve test coverage
2025-12-09 01:38:02 +00:00
lukaszraczylo 9126c74723 December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)
* Allow internal IPs for OIDC configuration via extra flag.

Addresses issue #97

* Allow for internal IPs in OIDC configuration.

Addresses issue #97.

* feat: Add allowPrivateIPAddresses config option for internal networks

Adds a new configuration option `allowPrivateIPAddresses` that allows
OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x,
192.168.x.x). This is useful for internal deployments where Keycloak or
other OIDC providers run on private networks without DNS resolution.

Security considerations:
- Loopback addresses (127.0.0.1, localhost, ::1) remain blocked
- Link-local addresses (169.254.x.x) remain blocked
- Default is false (secure by default)

Fixes #97

* feat: Support non-email user identifiers for Azure AD

Add userIdentifierClaim configuration option to support Azure AD users
without email addresses. This allows using alternative JWT claims like
"sub", "oid", "upn", or "preferred_username" for user identification.

- Default behavior uses "email" claim (backward compatible)
- Falls back to "sub" claim if configured claim is missing
- allowedUsers matches against the configured claim value
- allowedUserDomains only applies when using email-based identification

Fixes #95

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Headers too big and 431 responses

Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services.

  - When minimalHeaders: false (default): All headers are forwarded as before
    - X-Forwarded-User (always set)
    - X-Auth-Request-Redirect
    - X-Auth-Request-User
    - X-Auth-Request-Token (the large ID token)
    - X-User-Groups, X-User-Roles (if configured)
  - When minimalHeaders: true: Reduces header overhead
    - X-Forwarded-User (always set)
    - X-User-Groups, X-User-Roles (still forwarded if configured)
    - Custom templated headers (still processed)
    - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect

Fixes issues #64 and #86
2025-12-08 14:21:17 +00:00
lukaszraczylo a750c4f5b9 Size computation for allocation may overflow (#99)
* Size computation for allocation may overflow

Performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
2025-12-08 11:22:28 +00:00
lukaszraczylo 56051779ee Hotfix: goreleaser archive format. 2025-12-08 02:39:40 +00:00
263 changed files with 36881 additions and 52459 deletions
+1 -1
View File
@@ -18,6 +18,6 @@ jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24"
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+1 -1
View File
@@ -17,5 +17,5 @@ jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24"
go-version: "1.24.11"
secrets: inherit
+2 -1
View File
@@ -1,2 +1,3 @@
docker/
.claude/
.claude/*.out
*.test
+1 -1
View File
@@ -7,7 +7,7 @@ builds:
# Create source archive for GitHub releases
archives:
- format: tar.gz
- formats: [tar.gz]
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
files:
- "*.go"
+118
View File
@@ -77,6 +77,7 @@ testData:
# Custom claim names for Auth0 and other providers with namespaced claims
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
# When NOT specified in config: defaults to FALSE (Go zero value)
@@ -120,6 +121,8 @@ testData:
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
# Security Headers Configuration (enabled by default with 'default' profile)
securityHeaders:
@@ -266,6 +269,8 @@ testDataWithRedis:
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
# - admin
# - editor
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
# # See README.md "Provider Configuration Recommendations" for Keycloak.
@@ -287,6 +292,26 @@ testDataWithRedis:
# - "AppRoleName"
# # See README.md "Provider Configuration Recommendations" for Azure AD.
# --- Azure AD Users Without Email Example (Issue #95) ---
# testDataAzureADNoEmail:
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
# clientID: your-azure-ad-client-id
# clientSecret: your-azure-ad-client-secret
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
# # Use 'sub' claim instead of 'email' for user identification
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
# overrideScopes: true # Remove email scope if not needed
# scopes:
# - openid
# - profile
# - groups # For group-based access control
# # When using non-email identifiers, allowedUsers matches against the claim value
# allowedUsers:
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
# --- Google Workspace / Google Cloud Identity Example ---
# testDataGoogle:
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
@@ -605,6 +630,38 @@ configuration:
items:
type: string
userIdentifierClaim:
type: string
description: |
Specifies the JWT claim to use as the user identifier for authentication and authorization.
This allows authentication for users without email addresses, such as Azure AD service
accounts or organizational accounts that don't have email attributes configured.
When set to a non-email claim (e.g., "sub", "oid", "upn"):
- AllowedUsers will match against this claim value instead of email
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
- The session stores this identifier as the user's identity
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
Common values by provider:
- Default: "email" (standard email-based identification)
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
- Generic OIDC: "sub" (always present per OIDC specification)
- Keycloak: "sub", "preferred_username"
Example for Azure AD users without email:
```yaml
userIdentifierClaim: sub
allowedUsers:
- "abc123-user-object-id"
- "xyz789-another-user-id"
```
Default: "email"
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
required: false
revocationURL:
type: string
description: |
@@ -903,6 +960,67 @@ configuration:
Default: false (replay detection enabled)
required: false
allowPrivateIPAddresses:
type: boolean
description: |
Allow private IP addresses in OIDC provider URLs for internal network deployments.
By default, the plugin blocks URLs containing private IP address ranges
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
OIDC providers are publicly accessible.
Enable this option when:
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
- You don't have DNS resolution available for internal services
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
When enabled, the plugin will accept provider URLs like:
- https://192.168.1.100:8443/auth/realms/your-realm
- https://10.0.0.50:8080/realms/master
- https://172.16.0.10/auth
Security Warning:
Enabling this option reduces SSRF protection. Only use in trusted network
environments where the OIDC provider is known and controlled. Loopback
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
Default: false (private IPs are blocked for security)
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
required: false
minimalHeaders:
type: boolean
description: |
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
When enabled, the middleware only forwards the X-Forwarded-User header and skips
the larger authentication headers that can cause downstream services to reject
requests due to header size limits (typically 8KB).
Headers when disabled (default):
- X-Forwarded-User: User's email address (always set)
- X-Auth-Request-Redirect: Original request URI
- X-Auth-Request-User: User's email address
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
- X-User-Groups: Comma-separated user groups (if configured)
- X-User-Roles: Comma-separated user roles (if configured)
Headers when enabled:
- X-Forwarded-User: User's email address (always set)
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
- Custom templated headers (still processed)
Use this option when:
- Downstream services return "431 Request Header Fields Too Large" errors
- Your ID tokens are large (many claims, long group lists)
- You don't need the full ID token forwarded to backend services
- You want to reduce request overhead
Default: false (all headers forwarded for backward compatibility)
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
headers:
type: array
description: |
-286
View File
@@ -1,286 +0,0 @@
# CI/CD Setup Guide
## 📋 Overview
This repository now has a comprehensive CI/CD pipeline that runs **20+ parallel checks** on every pull request to ensure code quality, security, and reliability.
## 🎯 What Was Added
### GitHub Actions Workflow
- **`.github/workflows/pr-validation.yml`** - Main CI/CD pipeline (single file, all parallel)
### Configuration Files
- **`.golangci.yml`** - Linter configuration with 30+ enabled checks
- **`.github/dependabot.yml`** - Automated dependency updates
- **`.github/CODEOWNERS`** - Automatic PR reviewer assignment
- **`.github/PULL_REQUEST_TEMPLATE.md`** - Standardized PR descriptions
- **`.github/workflows/README.md`** - Detailed workflow documentation
- **`.github/workflows/.gitattributes`** - Consistent line endings
## ✅ What Gets Tested (All in Parallel)
### Code Quality (3 checks)
- **Format & Basic Checks** - gofmt, go vet, go mod
- **golangci-lint** - 30+ linters including style, complexity, bugs
- **Staticcheck** - Advanced static analysis
### Security (3 checks)
- **Gosec** - Security vulnerability scanning with SARIF reports
- **Govulncheck** - Go vulnerability database scanning
- **CodeQL** - GitHub's semantic code analysis
### Testing (9 test suites)
- **Race Detector** - Concurrent access bugs
- **Coverage** - 75% threshold with PR comments
- **Memory Leaks** - Goroutine and memory leak detection
- **Integration Tests** - Full integration suite
- **Regression Tests** - Prevent old bugs from returning
- **Security Edge Cases** - Security-specific scenarios
- **Session Tests** - Session management
- **Token Tests** - Token validation
- **CSRF Tests** - CSRF protection
### Provider Testing (9 providers in parallel)
- Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, GitHub, Generic
### Performance & Build (3 checks)
- **Benchmarks** - Performance regression detection
- **Multi-platform Build** - 4 combinations (linux/darwin × amd64/arm64)
- **Go Version Compatibility** - Go 1.23 & 1.24
## 🚀 Quick Start
### 1. Push to GitHub
```bash
git add .github .golangci.yml CI_SETUP.md
git commit -m "Add comprehensive CI/CD pipeline"
git push origin main
```
### 2. Create a Test PR
```bash
# Create a feature branch
git checkout -b feature/test-ci
echo "# Test" >> test.md
git add test.md
git commit -m "Test CI pipeline"
git push origin feature/test-ci
# Create PR on GitHub
# Watch all 20+ checks run in parallel! ⚡
```
### 3. Monitor Results
- Go to Actions tab: `https://github.com/{owner}/{repo}/actions`
- Click on latest workflow run
- See all parallel checks in action
- Review coverage comment on PR
## 📊 Key Features
### ⚡ Maximum Speed
- **Parallel execution** - All checks run simultaneously
- **Smart caching** - Go modules and build cache
- **Optimized order** - Quick checks first for fast feedback
- **Expected runtime**: 5-10 minutes for full suite
### 🔒 Security First
- **3 security scanners** - gosec, govulncheck, CodeQL
- **SARIF integration** - Results in GitHub Security tab
- **Dependency scanning** - Automated with Dependabot
- **Security edge case tests**
### 📈 Coverage Tracking
- **Automatic PR comments** with coverage stats
- **Per-package breakdown** included
- **75% threshold** enforced (configurable)
- **Codecov integration** ready (optional)
### 🎨 Developer Experience
- **Clear PR template** guides contributors
- **Auto code owners** assignment
- **Detailed error messages** for failures
- **Benchmark tracking** for performance
## 🛠️ Local Development
### Install Required Tools
```bash
# golangci-lint (comprehensive linting)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck (static analysis)
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec (security scanning)
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck (vulnerability scanning)
go install golang.org/x/vuln/cmd/govulncheck@latest
```
### Run Checks Locally
```bash
# Quick validation (before committing)
gofmt -s -w . # Format code
go vet ./... # Basic checks
go mod tidy # Clean dependencies
# Linting
golangci-lint run # Full lint suite
staticcheck ./... # Static analysis
# Testing
go test -race -timeout=15m ./... # Tests with race detector
go test -coverprofile=coverage.out ./... # Coverage
go tool cover -func=coverage.out # View coverage
# Security
gosec ./... # Security scan
govulncheck ./... # Vulnerability check
# Benchmarks
go test -bench=. -benchmem ./... # Performance tests
```
### Pre-commit Checklist
```bash
# Run this before every commit
gofmt -s -w . && \
go mod tidy && \
golangci-lint run && \
go test -race -short ./... && \
echo "✅ Ready to commit!"
```
## 📝 Configuration
### Adjust Coverage Threshold
Edit `.github/workflows/pr-validation.yml`:
```yaml
THRESHOLD=75 # Change to desired percentage
```
### Modify Linter Rules
Edit `.golangci.yml`:
```yaml
linters:
enable:
- newlinter # Add new linters here
```
### Update Go Version
Edit `.github/workflows/pr-validation.yml`:
```yaml
go-version: '1.24' # Update version
```
## 🐛 Troubleshooting
### Coverage Below Threshold
```bash
# See uncovered lines in browser
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
### Race Condition Found
```bash
# Run specific test with race detector
go test -race -v -run=TestName ./...
```
### Linter Errors
```bash
# See detailed lint errors
golangci-lint run -v
# Auto-fix some issues
golangci-lint run --fix
```
### Provider Test Fails
```bash
# Test specific provider
go test -v -run='.*Azure.*' ./internal/providers/
```
## 📈 Metrics & Monitoring
### GitHub Actions Dashboard
- View all runs: `Actions` tab
- Filter by workflow, branch, status
- Download logs and artifacts
### Status Badge
Add to README.md:
```markdown
[![PR Validation](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml/badge.svg)](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml)
```
### Notifications
- Configure in: Settings → Notifications
- Email alerts for workflow failures
- Slack/Discord webhooks supported
## 🔄 Continuous Improvement
### Dependabot Updates
- Automatic weekly dependency checks (Mondays 9 AM)
- Security updates prioritized
- Groups patch updates together
### Code Owners
- Auto-assigns reviewers based on file paths
- Ensures expertise reviews changes
- Speeds up PR review process
## 📚 Additional Resources
- [Workflow Documentation](.github/workflows/README.md)
- [golangci-lint Rules](.golangci.yml)
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
- [Dependabot Config](.github/dependabot.yml)
## 🎉 Benefits
### For Contributors
- Clear expectations via PR template
- Fast feedback (5-10 min)
- Comprehensive local tooling
- Detailed error messages
### For Maintainers
- Automated code review
- Security scanning
- Performance tracking
- Quality gates enforcement
### For Users
- Higher code quality
- Fewer bugs in production
- Better security
- Consistent performance
## 🚦 Success Criteria
All PRs must pass:
- ✅ All 20+ parallel checks
- ✅ 75% test coverage minimum
- ✅ Zero security vulnerabilities
- ✅ No race conditions
- ✅ No memory leaks
- ✅ All providers tested
- ✅ Builds on all platforms
## 💡 Tips
1. **Run checks locally** before pushing to save CI time
2. **Watch for PR comments** - coverage stats posted automatically
3. **Check Security tab** for gosec/CodeQL findings
4. **Review benchmark results** in artifacts
5. **Use draft PRs** for work-in-progress to skip some checks
---
**Ready to go!** 🚀 Push your changes and create a PR to see it in action.
+84 -2
View File
@@ -124,6 +124,7 @@ The middleware supports the following configuration options:
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
@@ -138,6 +139,8 @@ The middleware supports the following configuration options:
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
@@ -1241,6 +1244,45 @@ spec:
- "AppRoleName" # Application role names
```
### Azure AD Configuration (Users Without Email)
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-no-email
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
clientID: your-azure-ad-client-id
clientSecret: your-azure-ad-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
# Use 'sub' instead of 'email' for user identification
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
overrideScopes: true # Optional: Don't request email scope if not needed
scopes:
- openid
- profile
- groups
# When using non-email identifiers, allowedUsers matches against the claim value
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
- "def67890-1234-5678-90ab-cdef12345678"
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
```
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
### Auth0 Configuration
```yaml
@@ -1327,8 +1369,12 @@ spec:
- admin
- editor
# Ensure Keycloak client mappers add necessary claims to ID Token
# For internal Keycloak deployments with private IPs (e.g., Docker network):
# allowPrivateIPAddresses: true
```
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
### AWS Cognito Configuration
```yaml
@@ -1629,12 +1675,39 @@ headers:
When a user is authenticated, the middleware sets the following headers for downstream services:
- `X-Forwarded-User`: The user's email address
- `X-Forwarded-User`: The user's email address (always set)
- `X-User-Groups`: Comma-separated list of user groups (if available)
- `X-User-Roles`: Comma-separated list of user roles (if available)
- `X-Auth-Request-Redirect`: The original request URI
- `X-Auth-Request-User`: The user's email address
- `X-Auth-Request-Token`: The user's access token
- `X-Auth-Request-Token`: The user's ID token (can be large)
#### Minimal Headers Mode
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
```yaml
http:
middlewares:
my-auth:
plugin:
traefikoidc:
minimalHeaders: true
# ... other config
```
When `minimalHeaders: true` is set:
- **Only forwards**: `X-Forwarded-User`
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
- **Still processes**: Custom templated headers
This is particularly useful when:
- Your ID tokens are large (many claims, long group lists)
- Downstream services have limited header buffer sizes (default 8KB in many servers)
- You don't need the full token forwarded to backend services
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
### Security Headers
@@ -1862,6 +1935,15 @@ logLevel: debug
- No refresh tokens (re-authentication required on expiry)
- Use only for GitHub API access, not user authentication
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
- Any name that doesn't contain the literal substring "API"
### Provider Warnings and Recommendations
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
+1375
View File
File diff suppressed because it is too large Load Diff
-931
View File
@@ -1,931 +0,0 @@
package traefikoidc
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
func TestConfigAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
wantErr bool
errContains string
}{
{
name: "Empty audience is valid for backward compatibility",
audience: "",
wantErr: false,
},
{
name: "Valid HTTPS URL audience Auth0 format",
audience: "https://api.example.com",
wantErr: false,
},
{
name: "Valid identifier audience",
audience: "my-api",
wantErr: false,
},
{
name: "Valid Azure AD Application ID URI format",
audience: "api://12345-guid-67890",
wantErr: false,
},
{
name: "Valid Auth0 API identifier",
audience: "https://my-company.auth0.com/api/v2/",
wantErr: false,
},
{
name: "HTTP URL audience should fail",
audience: "http://api.example.com",
wantErr: true,
errContains: "must use HTTPS",
},
{
name: "Audience with wildcard should fail",
audience: "https://api.*.example.com",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience with single asterisk should fail",
audience: "*",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience over 256 characters should fail",
audience: strings.Repeat("a", 257),
wantErr: true,
errContains: "must not exceed 256 characters",
},
{
name: "Audience with newline should fail",
audience: "my-api\ninjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with carriage return should fail",
audience: "my-api\rinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with tab should fail",
audience: "my-api\tinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Valid audience exactly 256 characters",
audience: strings.Repeat("a", 256),
wantErr: false,
},
{
name: "Valid simple identifier",
audience: "my-service-api",
wantErr: false,
},
{
name: "Valid URN format",
audience: "urn:myservice:api:v1",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceVerification tests JWT verification with custom audience values
func TestJWTAudienceVerification(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate RSA key for signing JWTs
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
// Create JWK
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tests := []struct {
name string
configAudience string
tokenAudience interface{}
wantErr bool
errContains string
skipReplayCheck bool
}{
{
name: "JWT with string aud matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://api.example.com",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with array aud containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with string aud NOT matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://wrong-api.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with array aud NOT containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with clientID as aud when no custom audience configured",
configAudience: "",
tokenAudience: "test-client-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with empty string aud",
configAudience: "https://api.example.com",
tokenAudience: "",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "Azure AD Application ID URI format",
configAudience: "api://12345-app-id",
tokenAudience: "api://12345-app-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Auth0 custom API audience",
configAudience: "https://mycompany.com/api",
tokenAudience: "https://mycompany.com/api",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Token confusion attack - audience for different service",
configAudience: "https://service-a.example.com",
tokenAudience: "https://service-b.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create TraefikOidc instance
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Set up the token verifier and JWT verifier
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Determine the expected audience for validation
expectedAudience := tt.configAudience
if expectedAudience == "" {
expectedAudience = tOidc.clientID
}
// Set the audience field on the tOidc instance
tOidc.audience = expectedAudience
// Create JWT with specified audience
jti := generateRandomString(16)
if tt.skipReplayCheck {
// Use a unique JTI for each test to avoid replay detection
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
}
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": tt.tokenAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Verify the token
err = tOidc.VerifyToken(jwt)
if (err != nil) != tt.wantErr {
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
// when the Audience field is not set
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Test with no custom audience configured - should use clientID
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id", // Should match clientID
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
err = ts.tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
}
}
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Auth0 scenario: custom audience for API access
config := CreateConfig()
config.ProviderURL = "https://mycompany.auth0.com"
config.ClientID = "auth0-client-id"
config.ClientSecret = "auth0-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "https://api.mycompany.com" // Custom API audience
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Auth0 config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "auth0-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches configured audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Auth0 token verification failed: %v", err)
}
})
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.ClientID, // Using clientID instead of API audience
"scope": "openid profile email", // Mark as access token
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err == nil {
t.Error("Auth0 access token with wrong audience should have been rejected")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error, got: %v", err)
}
})
}
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Azure AD scenario: Application ID URI format
config := CreateConfig()
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
config.ClientID = "azure-client-id"
config.ClientSecret = "azure-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Azure AD config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "azure-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches Application ID URI
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token verification failed: %v", err)
}
})
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
}
})
}
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
// Service A configuration
serviceA := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "service-a-client-id",
clientSecret: "service-a-secret",
audience: "service-a-client-id", // Service A uses its clientID as audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
serviceA.jwtVerifier = serviceA
serviceA.tokenVerifier = serviceA
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
// Create a token intended for service B
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": "https://service-b.example.com", // For service B
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "attacker@example.com",
"email": "attacker@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create service B token: %v", err)
}
// Try to verify the service B token on service A
err = serviceA.VerifyToken(serviceBToken)
switch {
case err == nil:
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
case !strings.Contains(err.Error(), "invalid audience"):
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
default:
t.Logf("Token confusion attack correctly prevented: %v", err)
}
})
}
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
func TestAudienceSecurityWildcardInjection(t *testing.T) {
tests := []struct {
name string
audience string
}{
{
name: "Single asterisk",
audience: "*",
},
{
name: "Wildcard in URL",
audience: "https://*.example.com",
},
{
name: "Wildcard in path",
audience: "https://api.example.com/*",
},
{
name: "Multiple wildcards",
audience: "https://*.*.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
}
})
}
}
// TestAudienceSecurityInjectionAttempts tests various injection attempts
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
tests := []struct {
name string
audience string
errContains string
}{
{
name: "Newline injection",
audience: "api.example.com\nmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Carriage return injection",
audience: "api.example.com\rmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Tab injection",
audience: "api.example.com\tmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Null byte injection",
audience: "api.example.com\x00malicious.com",
errContains: "contains invalid characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
}
})
}
}
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
func TestAudienceWithReplayProtection(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Create a token with custom audience and fixed JTI
fixedJTI := "replay-test-jti-" + generateRandomString(8)
customAudience := "https://api.example.com"
// Set the audience field to match what we expect
tOidc.audience = customAudience
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-user",
"email": "user@example.com",
"jti": fixedJTI,
})
if err != nil {
t.Fatalf("Failed to create JWT: %v", err)
}
// First verification should succeed
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Fatalf("First verification failed: %v", err)
}
// Verify that the JTI was blacklisted
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
} else {
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
}
}
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
func TestAudienceEndToEndScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Create a test next handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Authenticated with custom audience"))
})
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
customAudience := "https://api.company.com"
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://auth.company.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
audience: customAudience, // Set custom audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.company.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
allowedUserDomains: map[string]struct{}{"company.com": {}},
excludedURLs: map[string]struct{}{},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
close(tOidc.initComplete)
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
// Create a valid token with the custom audience
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.company.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "user-123",
"email": "user@company.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create valid JWT: %v", err)
}
// Create a request with authenticated session
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
// Create session with token
resp := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("Failed to set authenticated: %v", err)
}
session.SetEmail("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
if err := session.Save(req, resp); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies and add them to a new request
cookies := resp.Result().Cookies()
req = httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp = httptest.NewRecorder()
tOidc.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
}
})
}
-409
View File
@@ -1,409 +0,0 @@
// Package auth provides authentication-related functionality for the OIDC middleware.
package auth
import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/google/uuid"
)
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
type ScopeFilter interface {
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
}
// Handler provides core authentication functionality for OIDC flows
type Handler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
scopeFilter ScopeFilter // NEW
scopesSupported []string // NEW - from provider metadata
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// NewAuthHandler creates a new Handler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
return &Handler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
scopeFilter: scopeFilter, // NEW
scopesSupported: scopesSupported, // NEW
}
}
// InitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
func (h *Handler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
session SessionData, redirectURL string,
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return
}
session.IncrementRedirectCount()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
h.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if h.enablePKCE {
codeVerifier, err = generateCodeVerifier()
if err != nil {
h.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
codeChallenge, err = deriveCodeChallenge()
if err != nil {
h.logger.Errorf("Failed to generate code challenge: %v", err)
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
return
}
h.logger.Debugf("PKCE enabled, generated code challenge")
}
session.SetAuthenticated(false)
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if h.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// BuildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
func (h *Handler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", h.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
if h.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
scopes := make([]string, len(h.scopes))
copy(scopes, h.scopes)
// Apply discovery-based scope filtering if available
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
}
// Apply provider-specific modifications
scopes, params = h.applyProviderSpecificConfig(scopes, params)
// Final filtering pass to remove anything the provider doesn't support
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
return h.buildURLWithParams(h.authURL, params)
}
// applyProviderSpecificConfig applies provider-specific scope and parameter modifications
func (h *Handler) applyProviderSpecificConfig(scopes []string, params url.Values) ([]string, url.Values) {
switch {
case h.isGoogleProv():
return h.applyGoogleConfig(scopes, params)
case h.isAzureProv():
return h.applyAzureConfig(scopes, params)
default:
return h.applyStandardProviderConfig(scopes, params)
}
}
// applyGoogleConfig applies Google-specific configuration
func (h *Handler) applyGoogleConfig(scopes []string, params url.Values) ([]string, url.Values) {
// Google: Remove offline_access if present, add access_type=offline
filteredScopes := make([]string, 0, len(scopes))
for _, scope := range scopes {
if scope != "offline_access" {
filteredScopes = append(filteredScopes, scope)
}
}
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
params.Set("prompt", "consent")
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
return filteredScopes, params
}
// applyAzureConfig applies Azure AD-specific configuration
func (h *Handler) applyAzureConfig(scopes []string, params url.Values) ([]string, url.Values) {
params.Set("response_mode", "query")
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
if h.shouldAddOfflineAccess(scopes) {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
h.overrideScopes, len(h.scopes))
} else {
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.",
len(h.scopes))
}
return scopes, params
}
// applyStandardProviderConfig applies configuration for standard OIDC providers
func (h *Handler) applyStandardProviderConfig(scopes []string, params url.Values) ([]string, url.Values) {
if h.shouldAddOfflineAccess(scopes) {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
h.overrideScopes, len(h.scopes))
} else {
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.",
len(h.scopes))
}
return scopes, params
}
// shouldAddOfflineAccess determines if offline_access scope should be added
func (h *Handler) shouldAddOfflineAccess(scopes []string) bool {
if h.overrideScopes && len(h.scopes) > 0 {
return false
}
for _, scope := range scopes {
if scope == "offline_access" {
return false
}
}
return true
}
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
// It handles both relative and absolute URLs, validates URL security,
// and properly encodes query parameters.
func (h *Handler) buildURLWithParams(baseURL string, params url.Values) string {
if baseURL != "" {
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
if err := h.validateURL(baseURL); err != nil {
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
return ""
}
}
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
issuerURLParsed, err := url.Parse(h.issuerURL)
if err != nil {
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
return ""
}
baseURLParsed, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
return ""
}
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
if err := h.validateURL(resolvedURL.String()); err != nil {
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
return ""
}
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
u, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
return ""
}
if err := h.validateParsedURL(u); err != nil {
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
return ""
}
u.RawQuery = params.Encode()
return u.String()
}
// validateURL performs security validation on URLs to prevent SSRF attacks.
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
func (h *Handler) validateURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("empty URL")
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
return h.validateParsedURL(u)
}
// validateParsedURL validates a parsed URL structure for security.
// It checks schemes, hosts, and paths to prevent malicious URLs.
func (h *Handler) validateParsedURL(u *url.URL) error {
allowedSchemes := map[string]bool{
"https": true,
"http": true,
}
if !allowedSchemes[u.Scheme] {
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
}
if u.Scheme == "http" {
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if err := h.validateHost(u.Host); err != nil {
return fmt.Errorf("invalid host: %w", err)
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// validateHost validates a hostname for security and reachability.
// It prevents access to private networks and localhost addresses.
func (h *Handler) validateHost(host string) error {
if host == "" {
return fmt.Errorf("empty host")
}
// Strip port if present
if strings.Contains(host, ":") {
var err error
host, _, err = net.SplitHostPort(host)
if err != nil {
return fmt.Errorf("invalid host:port format: %w", err)
}
}
// Check for localhost variations
localhostVariations := []string{
"localhost", "127.0.0.1", "::1", "0.0.0.0",
}
for _, localhost := range localhostVariations {
if strings.EqualFold(host, localhost) {
return fmt.Errorf("localhost access not allowed: %s", host)
}
}
// Try to parse as IP address
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return fmt.Errorf("loopback IP not allowed: %s", host)
}
if ip.IsPrivate() {
return fmt.Errorf("private IP not allowed: %s", host)
}
if ip.IsLinkLocalUnicast() {
return fmt.Errorf("link-local IP not allowed: %s", host)
}
if ip.IsMulticast() {
return fmt.Errorf("multicast IP not allowed: %s", host)
}
}
return nil
}
// SessionData interface for dependency injection
type SessionData interface {
GetRedirectCount() int
ResetRedirectCount()
IncrementRedirectCount()
SetAuthenticated(bool)
SetEmail(string)
SetAccessToken(string)
SetRefreshToken(string)
SetIDToken(string)
SetNonce(string)
SetCodeVerifier(string)
SetCSRF(string)
SetIncomingPath(string)
MarkDirty()
Save(req *http.Request, rw http.ResponseWriter) error
}
File diff suppressed because it is too large Load Diff
-562
View File
@@ -1,562 +0,0 @@
package auth
import (
"net/url"
"strings"
"testing"
)
// TestAuthHandler_validateURL tests URL validation functionality
func TestAuthHandler_validateURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
url string
wantErr bool
errMsg string
}{
{
name: "Valid HTTPS URL",
url: "https://example.com/auth",
wantErr: false,
},
{
name: "Valid HTTP URL",
url: "http://example.com/auth",
wantErr: false,
},
{
name: "Empty URL",
url: "",
wantErr: true,
errMsg: "empty URL",
},
{
name: "Invalid URL format",
url: "not-a-url",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - javascript",
url: "javascript:alert('xss')",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - data",
url: "data:text/html,<script>alert('xss')</script>",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - file",
url: "file:///etc/passwd",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - ftp",
url: "ftp://example.com/file",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Missing host",
url: "https:///path",
wantErr: true,
errMsg: "missing host",
},
{
name: "Path traversal attempt",
url: "https://example.com/../../../etc/passwd",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Path traversal in middle",
url: "https://example.com/path/../sensitive/file",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Localhost attempt",
url: "https://localhost/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1 attempt",
url: "https://127.0.0.1/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "IPv6 localhost attempt",
url: "https://[::1]/auth",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "0.0.0.0 attempt",
url: "https://0.0.0.0/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Private IP - 192.168.x.x",
url: "https://192.168.1.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP - 10.x.x.x",
url: "https://10.0.0.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP - 172.16.x.x",
url: "https://172.16.0.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Link-local IP",
url: "https://169.254.1.1/auth",
wantErr: true,
errMsg: "link-local IP not allowed",
},
{
name: "Multicast IP",
url: "https://224.0.0.1/auth",
wantErr: true,
errMsg: "multicast IP not allowed",
},
{
name: "Valid public IP",
url: "https://8.8.8.8/auth",
wantErr: false,
},
{
name: "Valid domain with port",
url: "https://example.com:8443/auth",
wantErr: false,
},
{
name: "localhost with case variation",
url: "https://LOCALHOST/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Invalid host:port format",
url: "https://example.com:notanumber/auth",
wantErr: true,
errMsg: "invalid URL format",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := handler.validateURL(tt.url)
if tt.wantErr {
if err == nil {
t.Errorf("validateURL() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateURL() unexpected error = %v", err)
}
}
})
}
}
// TestAuthHandler_validateHost tests host validation specifically
func TestAuthHandler_validateHost(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
host string
wantErr bool
errMsg string
}{
{
name: "Valid hostname",
host: "example.com",
wantErr: false,
},
{
name: "Valid hostname with subdomain",
host: "api.example.com",
wantErr: false,
},
{
name: "Valid hostname with port",
host: "example.com:8080",
wantErr: false,
},
{
name: "Empty host",
host: "",
wantErr: true,
errMsg: "empty host",
},
{
name: "localhost",
host: "localhost",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "LOCALHOST (case insensitive)",
host: "LOCALHOST",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "localhost with port",
host: "localhost:8080",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1",
host: "127.0.0.1",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1 with port",
host: "127.0.0.1:8080",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "IPv6 localhost",
host: "::1",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "0.0.0.0",
host: "0.0.0.0",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Private IP 192.168.1.1",
host: "192.168.1.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP 10.0.0.1",
host: "10.0.0.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP 172.16.0.1",
host: "172.16.0.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Public IP 8.8.8.8",
host: "8.8.8.8",
wantErr: false,
},
{
name: "Link-local IP",
host: "169.254.1.1",
wantErr: true,
errMsg: "link-local IP not allowed",
},
{
name: "Multicast IP",
host: "224.0.0.1",
wantErr: true,
errMsg: "multicast IP not allowed",
},
{
name: "Invalid host:port format",
host: "example.com::",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "Valid international domain",
host: "example.org",
wantErr: false,
},
{
name: "Valid ccTLD",
host: "example.co.uk",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := handler.validateHost(tt.host)
if tt.wantErr {
if err == nil {
t.Errorf("validateHost() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateHost() unexpected error = %v", err)
}
}
})
}
}
// TestAuthHandler_buildURLWithParams tests URL building with parameters
func TestAuthHandler_buildURLWithParams(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
baseURL string
params url.Values
expected string
expectEmpty bool
}{
{
name: "Absolute HTTPS URL",
baseURL: "https://provider.com/auth",
params: url.Values{
"client_id": []string{"test-client"},
"response_type": []string{"code"},
},
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
},
{
name: "Absolute HTTP URL",
baseURL: "http://provider.com/auth",
params: url.Values{
"state": []string{"test-state"},
},
expected: "http://provider.com/auth?state=test-state",
},
{
name: "Relative URL resolved against issuer",
baseURL: "/oauth2/authorize",
params: url.Values{
"scope": []string{"openid"},
},
expected: "https://example.com/oauth2/authorize?scope=openid",
},
{
name: "Root relative URL",
baseURL: "/auth",
params: url.Values{
"nonce": []string{"test-nonce"},
},
expected: "https://example.com/auth?nonce=test-nonce",
},
{
name: "Invalid absolute URL",
baseURL: "https://localhost/auth",
params: url.Values{},
expectEmpty: true, // Should return empty string due to validation failure
},
{
name: "Invalid relative URL when resolved",
baseURL: "/auth",
params: url.Values{},
expected: "", // Should be empty because issuer validation would be tested separately
expectEmpty: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildURLWithParams(tt.baseURL, tt.params)
if tt.expectEmpty {
if result != "" {
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
}
return
}
// For relative URLs, we expect them to be resolved against the issuer URL
if !strings.HasPrefix(tt.baseURL, "http") {
// Verify it starts with the issuer URL
if !strings.HasPrefix(result, handler.issuerURL) {
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
}
}
// Parse the result to verify parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
}
// Verify all expected parameters are present
resultParams := parsedURL.Query()
for key, expectedValues := range tt.params {
actualValues := resultParams[key]
if len(actualValues) != len(expectedValues) {
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
continue
}
for i, expectedValue := range expectedValues {
if actualValues[i] != expectedValue {
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
}
}
}
})
}
}
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
// Test special characters that need encoding
params := url.Values{
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
"state": []string{"state with spaces and & special chars"},
"scope": []string{"openid profile email"},
"special": []string{"value+with+plus&ampersand=equals"},
}
result := handler.buildURLWithParams("https://provider.com/auth", params)
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse result URL: %v", err)
}
// Verify parameters are correctly encoded/decoded
resultParams := parsedURL.Query()
expectedParams := map[string]string{
"redirect_uri": "https://example.com/callback?test=value&other=data",
"state": "state with spaces and & special chars",
"scope": "openid profile email",
"special": "value+with+plus&ampersand=equals",
}
for key, expectedValue := range expectedParams {
actualValue := resultParams.Get(key)
if actualValue != expectedValue {
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
}
}
}
// TestAuthHandler_validateParsedURL tests validateParsedURL method
func TestAuthHandler_validateParsedURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
url string
wantErr bool
errMsg string
}{
{
name: "Valid HTTPS URL",
url: "https://example.com/path",
wantErr: false,
},
{
name: "Valid HTTP URL with warning",
url: "http://example.com/path",
wantErr: false, // Should not error but should log warning
},
{
name: "Invalid scheme",
url: "javascript:alert('xss')",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Missing host",
url: "https:///path",
wantErr: true,
errMsg: "missing host",
},
{
name: "Path traversal",
url: "https://example.com/path/../../../etc",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Invalid host (private IP)",
url: "https://192.168.1.1/path",
wantErr: true,
errMsg: "invalid host",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsedURL, err := url.Parse(tt.url)
if err != nil {
t.Fatalf("Failed to parse test URL: %v", err)
}
err = handler.validateParsedURL(parsedURL)
if tt.wantErr {
if err == nil {
t.Errorf("validateParsedURL() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateParsedURL() unexpected error = %v", err)
}
// Check for HTTP warning in debug logs
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
found := false
for _, msg := range logger.debugMessages {
if strings.Contains(msg, "Warning: Using HTTP scheme") {
found = true
break
}
}
if !found {
t.Error("Expected HTTP scheme warning in debug logs")
}
}
}
})
}
}
-428
View File
@@ -1,428 +0,0 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file contains tests for Auth0-specific audience validation scenarios.
package traefikoidc
import (
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestAuth0Scenario1WithCustomAudience tests Auth0 scenario 1:
// - Custom audience configured in plugin
// - Authorize endpoint called WITH audience parameter
// - ID token: aud = client_id
// - Access token: aud = [userinfo, custom_audience]
// Expected: Both tokens validate correctly
func TestAuth0Scenario1WithCustomAudience(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
customAudience := "https://my-api.example.com"
ts.tOidc.audience = customAudience
// Create ID token with aud = client_id (OIDC standard)
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id", // ID token always has client_id
"nonce": "test-nonce-scenario1", // ID tokens have nonce per OIDC spec
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "id-token-jti",
})
if err != nil {
t.Fatalf("Failed to create ID token: %v", err)
}
// Create access token with aud = [userinfo, custom_audience]
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": []interface{}{
"https://test-issuer.com/userinfo",
customAudience, // Custom API audience
},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"scope": "openid profile email read:data", // Access tokens have scope
"jti": "access-token-jti",
})
if err != nil {
t.Fatalf("Failed to create access token: %v", err)
}
// Verify ID token validates against client_id
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(idToken)
if err != nil {
t.Errorf("ID token validation failed (should validate against client_id): %v", err)
}
// Verify access token validates against custom audience
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(accessToken)
if err != nil {
t.Errorf("Access token validation failed (should validate against custom audience): %v", err)
}
// Verify buildAuthURL includes audience parameter (URL-encoded)
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
if !strings.Contains(authURL, "audience=") {
t.Errorf("Auth URL should contain audience parameter when custom audience is configured, got: %s", authURL)
}
// Verify the audience is properly URL-encoded (contains %3A for :, %2F for /)
if !strings.Contains(authURL, "audience=https%3A%2F%2Fmy-api.example.com") {
t.Errorf("Auth URL should contain URL-encoded custom audience, got: %s", authURL)
}
}
// TestAuth0Scenario2DefaultAudience tests Auth0 scenario 2:
// - No custom audience configured (defaults to client_id)
// - Authorize endpoint called WITHOUT audience parameter
// - ID token: aud = client_id
// - Access token: aud = [userinfo, default_audience] (no client_id)
// Expected: ID token validates, access token falls back to ID token validation
func TestAuth0Scenario2DefaultAudience(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// No custom audience - defaults to client_id
ts.tOidc.audience = ts.tOidc.clientID
// Create ID token with aud = client_id
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"nonce": "test-nonce-scenario2", // ID tokens have nonce per OIDC spec
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "id-token-jti-2",
})
if err != nil {
t.Fatalf("Failed to create ID token: %v", err)
}
// Create access token with aud = [userinfo, some_default_audience]
// This represents Auth0's default audience behavior
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": []interface{}{
"https://test-issuer.com/userinfo",
"https://test-issuer.com/api/v2/", // Default Auth0 Management API
},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"scope": "openid profile email",
"jti": "access-token-jti-2",
})
if err != nil {
t.Fatalf("Failed to create access token: %v", err)
}
// Verify ID token validates
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(idToken)
if err != nil {
t.Errorf("ID token validation failed: %v", err)
}
// Access token won't have client_id in aud, so it will fail validation
// This is expected for scenario 2 - the session validation relies on ID token
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(accessToken)
if err == nil {
t.Logf("Access token validation passed (unexpected but OK if client_id is in aud array)")
} else {
// Expected failure - access token doesn't have client_id in aud
t.Logf("Access token validation failed as expected (aud doesn't contain client_id): %v", err)
}
// Verify buildAuthURL does NOT include audience parameter (since audience == client_id)
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
if strings.Contains(authURL, "audience=") {
t.Errorf("Auth URL should NOT contain audience parameter when audience equals client_id, got: %s", authURL)
}
}
// TestAuth0Scenario3OpaqueAccessToken tests Auth0 scenario 3:
// - No custom audience configured
// - No default audience in Auth0
// - ID token: aud = client_id
// - Access token: opaque (not JWT)
// Expected: ID token validates, opaque access token is accepted
func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Enable opaque tokens for this scenario (Option C requirement)
ts.tOidc.allowOpaqueTokens = true
// No custom audience
ts.tOidc.audience = ts.tOidc.clientID
// Create ID token
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"nonce": "test-nonce-scenario3", // ID tokens have nonce per OIDC spec
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "id-token-jti-3",
})
if err != nil {
t.Fatalf("Failed to create ID token: %v", err)
}
// Opaque access token (not a JWT - just a random string)
opaqueAccessToken := "opaque_access_token_random_string_12345"
// Verify ID token validates
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(idToken)
if err != nil {
t.Errorf("ID token validation failed: %v", err)
}
// Opaque access token should fail JWT validation (expected)
err = ts.tOidc.VerifyToken(opaqueAccessToken)
if err == nil {
t.Error("Opaque access token should fail JWT validation")
} else {
t.Logf("Opaque access token correctly rejected by JWT validator: %v", err)
}
// Test that validateStandardTokens handles opaque tokens correctly
// by falling back to ID token validation
req := httptest.NewRequest("GET", "https://example.com/test", nil)
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetAuthenticated(true)
session.SetAccessToken(opaqueAccessToken)
session.SetIDToken(idToken)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
if !authenticated || needsRefresh || expired {
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
authenticated, needsRefresh, expired)
}
}
// TestAuth0AudienceArrayValidation tests that audience validation
// correctly handles array audiences (common in Auth0)
func TestAuth0AudienceArrayValidation(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
customAudience := "https://my-api.example.com"
ts.tOidc.audience = customAudience
// Access token with audience as array containing our custom audience
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": []interface{}{
"https://test-issuer.com/userinfo",
customAudience,
"https://another-api.example.com",
},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"scope": "openid profile email read:data write:data",
"jti": "array-aud-token-jti",
})
if err != nil {
t.Fatalf("Failed to create access token: %v", err)
}
// Should validate successfully - custom audience is in the array
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(accessToken)
if err != nil {
t.Errorf("Access token with audience array should validate when custom audience is present: %v", err)
}
}
// TestAuth0MismatchedAudience tests that tokens with wrong audience fail validation
func TestAuth0MismatchedAudience(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
customAudience := "https://my-api.example.com"
ts.tOidc.audience = customAudience
// Access token with WRONG audience
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": []interface{}{
"https://test-issuer.com/userinfo",
"https://different-api.example.com", // Wrong audience
},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"scope": "openid profile email",
"jti": "wrong-aud-token-jti",
})
if err != nil {
t.Fatalf("Failed to create access token: %v", err)
}
// Should fail validation - audience doesn't match
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(accessToken)
if err == nil {
t.Error("Access token with wrong audience should fail validation")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error, got: %v", err)
}
}
// TestAuth0Scenario2StrictMode tests strict audience validation mode:
// - Scenario 2 (access token with wrong audience) should be REJECTED
// - strictAudienceValidation=true prevents fallback to ID token
// - This addresses Allan's security concerns about audience bypass
func TestAuth0Scenario2StrictMode(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Enable strict mode to prevent Scenario 2 bypass (Option C)
ts.tOidc.strictAudienceValidation = true
// Configure custom audience
customAudience := "https://my-api.example.com"
ts.tOidc.audience = customAudience
// Create ID token with aud = client_id (valid)
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"nonce": "test-nonce-strict",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "id-token-strict-jti",
})
if err != nil {
t.Fatalf("Failed to create ID token: %v", err)
}
// Create access token with WRONG audience (doesn't include custom audience)
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": []interface{}{
"https://test-issuer.com/userinfo",
"https://wrong-api.example.com", // Wrong audience - not our custom audience
},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"scope": "openid profile email",
"jti": "access-token-strict-jti",
})
if err != nil {
t.Fatalf("Failed to create access token: %v", err)
}
// Test session validation with wrong access token and valid ID token
req := httptest.NewRequest("GET", "https://example.com/test", nil)
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetAuthenticated(true)
session.SetAccessToken(accessToken)
session.SetIDToken(idToken)
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
// In strict mode, this should FAIL (no fallback to ID token)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
if authenticated {
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
}
if !needsRefresh {
t.Errorf("Strict mode: Should signal refresh needed after rejection, got needsRefresh=%v", needsRefresh)
}
if expired {
t.Errorf("Strict mode: Should not mark as expired (should try refresh first), got expired=%v", expired)
}
t.Logf("✓ Strict mode correctly rejected Scenario 2 (access token audience mismatch)")
}
// TestIDTokenAlwaysValidatesAgainstClientID verifies that ID tokens
// are ALWAYS validated against client_id, regardless of configured audience
func TestIDTokenAlwaysValidatesAgainstClientID(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure a custom audience different from client_id
customAudience := "https://my-api.example.com"
ts.tOidc.audience = customAudience
// Create ID token with aud = client_id (per OIDC spec)
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id", // ID token MUST have client_id
"nonce": "test-nonce-123", // ID tokens have nonce for replay protection
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "id-token-client-id-jti",
})
if err != nil {
t.Fatalf("Failed to create ID token: %v", err)
}
// Should validate successfully - ID tokens are checked against client_id
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(idToken)
if err != nil {
t.Errorf("ID token should validate against client_id even when custom audience is configured: %v", err)
}
// Create ID token with WRONG audience (should fail)
wrongIDToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": customAudience, // WRONG - should be client_id
"nonce": "test-nonce-wrong-456", // ID token has nonce, so it will be detected as ID token
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
"jti": "wrong-id-token-jti",
})
if err != nil {
t.Fatalf("Failed to create wrong ID token: %v", err)
}
// Should fail - ID tokens must have client_id as audience
cleanupReplayCache()
initReplayCache()
err = ts.tOidc.VerifyToken(wrongIDToken)
if err == nil {
t.Error("ID token with custom audience (not client_id) should fail validation")
}
}
+19 -13
View File
@@ -8,10 +8,6 @@ import (
"github.com/google/uuid"
)
// ============================================================================
// AUTHENTICATION FLOW
// ============================================================================
// validateRedirectCount checks if redirect limit is exceeded and handles the error
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
const maxRedirects = 5
@@ -223,15 +219,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("Email claim missing or empty in token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
}
if !t.isAllowedDomain(email) {
t.logger.Errorf("Disallowed email domain during callback: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
// Validate user authorization
if !t.isAllowedUser(userIdentifier) {
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
@@ -240,7 +246,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
File diff suppressed because it is too large Load Diff
+1
View File
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
}
if mm.logger != nil {
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
}
+241
View File
@@ -0,0 +1,241 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// =============================================================================
// UNIVERSAL CACHE BENCHMARKS
// =============================================================================
func BenchmarkCacheSet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
i++
}
})
}
func BenchmarkCacheGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
for i := 0; i < 1000; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Get(fmt.Sprintf("key%d", i%1000))
i++
}
})
}
func BenchmarkCacheSetGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key%d", i)
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
cache.Get(key)
i++
}
})
}
func BenchmarkCacheLRUEviction(b *testing.B) {
config := createTestCacheConfig()
config.MaxSize = 100
cache := NewUniversalCache(config)
defer cache.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
}
func BenchmarkCacheConcurrent(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
switch i % 3 {
case 0:
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
case 1:
cache.Get(fmt.Sprintf("key%d", i))
case 2:
cache.Delete(fmt.Sprintf("key%d", i))
}
i++
}
})
}
// =============================================================================
// CACHE MANAGER BENCHMARKS
// =============================================================================
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
// =============================================================================
// CACHE COMPATIBILITY BENCHMARKS
// =============================================================================
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
// =============================================================================
// SHARDED CACHE BENCHMARKS
// =============================================================================
func BenchmarkShardedCache(b *testing.B) {
b.Run("Set", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
})
b.Run("Get", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
for i := 0; i < 10000; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get(fmt.Sprintf("key-%d", i%10000))
}
})
b.Run("ParallelSetGet", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, i, 5*time.Minute)
cache.Get(key)
i++
}
})
})
}
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
b.Run("ShardedCache64", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
if !cache.Exists(key) {
cache.Set(key, true, 5*time.Minute)
}
i++
}
})
})
b.Run("GlobalMutexCache", func(b *testing.B) {
var mu sync.RWMutex
data := make(map[string]bool)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
mu.RLock()
_, exists := data[key]
mu.RUnlock()
if !exists {
mu.Lock()
data[key] = true
mu.Unlock()
}
i++
}
})
})
}
-369
View File
@@ -1,369 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
// TestNewBoundedCache tests creation of bounded cache
func TestNewBoundedCache(t *testing.T) {
maxSize := 500
cache := NewBoundedCache(maxSize)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify we can use basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestDefaultUnifiedCacheConfig tests default configuration
func TestDefaultUnifiedCacheConfig(t *testing.T) {
config := DefaultUnifiedCacheConfig()
if config.Type != CacheTypeGeneral {
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
}
if config.MaxSize != 500 {
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
}
if config.MaxMemoryBytes != 64*1024*1024 {
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
}
if config.CleanupInterval != 2*time.Minute {
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
}
if config.Logger == nil {
t.Error("Expected Logger to be set")
}
}
// TestNewUnifiedCache tests unified cache creation
func TestNewUnifiedCache(t *testing.T) {
config := DefaultUnifiedCacheConfig()
cache := NewUnifiedCache(config)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
if cache.UniversalCache == nil {
t.Error("Expected UniversalCache to be set")
}
// Test basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
func TestUnifiedCache_SetMaxSize(t *testing.T) {
config := DefaultUnifiedCacheConfig()
cache := NewUnifiedCache(config)
// Test setting max size
newSize := 1000
cache.SetMaxSize(newSize)
// We can't easily verify the size was set without exposing internal fields,
// but we can ensure the method doesn't panic
}
// TestNewCacheAdapter tests cache adapter creation
func TestNewCacheAdapter(t *testing.T) {
tests := []struct {
name string
cache interface{}
expectNil bool
description string
}{
{
name: "UniversalCache",
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
expectNil: false,
description: "Should create adapter for UniversalCache",
},
{
name: "UnifiedCache",
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
expectNil: false,
description: "Should create adapter for UnifiedCache",
},
{
name: "Invalid cache type",
cache: "not-a-cache",
expectNil: true,
description: "Should return nil for invalid cache type",
},
{
name: "Nil cache",
cache: nil,
expectNil: true,
description: "Should return nil for nil cache",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
adapter := NewCacheAdapter(tt.cache)
if tt.expectNil {
if adapter != nil {
t.Errorf("Expected nil adapter, got %v", adapter)
}
} else {
if adapter == nil {
t.Error("Expected non-nil adapter")
}
// Test basic operations
adapter.Set("test", "value", time.Hour)
value, found := adapter.Get("test")
if !found {
t.Error("Expected key to be found")
}
if value != "value" {
t.Errorf("Expected 'value', got %v", value)
}
}
})
}
}
// TestNewOptimizedCache tests optimized cache creation
func TestNewOptimizedCache(t *testing.T) {
cache := NewOptimizedCache()
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestNewLRUStrategy tests LRU strategy creation
func TestNewLRUStrategy(t *testing.T) {
maxSize := 100
strategy := NewLRUStrategy(maxSize)
if strategy == nil {
t.Fatal("Expected strategy to be created, got nil")
}
lruStrategy, ok := strategy.(*LRUStrategy)
if !ok {
t.Fatal("Expected LRUStrategy type")
}
if lruStrategy.maxSize != maxSize {
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
}
if lruStrategy.order == nil {
t.Error("Expected order list to be initialized")
}
if lruStrategy.elements == nil {
t.Error("Expected elements map to be initialized")
}
}
// TestLRUStrategy_Name tests strategy name
func TestLRUStrategy_Name(t *testing.T) {
strategy := NewLRUStrategy(100)
name := strategy.Name()
if name != "LRU" {
t.Errorf("Expected 'LRU', got %s", name)
}
}
// TestLRUStrategy_ShouldEvict tests eviction logic
func TestLRUStrategy_ShouldEvict(t *testing.T) {
strategy := NewLRUStrategy(100)
// LRU strategy always returns false for ShouldEvict
result := strategy.ShouldEvict("test-item", time.Now())
if result != false {
t.Error("Expected ShouldEvict to return false")
}
}
// TestLRUStrategy_OnAccess tests access callback
func TestLRUStrategy_OnAccess(t *testing.T) {
strategy := NewLRUStrategy(100)
// OnAccess should not panic
strategy.OnAccess("test-key", "test-value")
}
// TestLRUStrategy_OnRemove tests removal callback
func TestLRUStrategy_OnRemove(t *testing.T) {
strategy := NewLRUStrategy(100)
// OnRemove should not panic
strategy.OnRemove("test-key")
}
// TestLRUStrategy_EstimateSize tests size estimation
func TestLRUStrategy_EstimateSize(t *testing.T) {
strategy := NewLRUStrategy(100)
size := strategy.EstimateSize("test-item")
if size != 64 {
t.Errorf("Expected size 64, got %d", size)
}
}
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
strategy := NewLRUStrategy(100)
key, found := strategy.GetEvictionCandidate()
if found {
t.Error("Expected no eviction candidate to be found")
}
if key != "" {
t.Errorf("Expected empty key, got %s", key)
}
}
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
func TestNewOptimizedCacheWithConfig(t *testing.T) {
config := UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 1000,
MaxMemoryBytes: 128 * 1024 * 1024,
EnableMetrics: true,
Logger: GetSingletonNoOpLogger(),
}
cache := NewOptimizedCacheWithConfig(config)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestNewFixedMetadataCache tests fixed metadata cache creation
func TestNewFixedMetadataCache(t *testing.T) {
cache := NewFixedMetadataCache()
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with proper metadata operations
metadata := &ProviderMetadata{
Issuer: "https://example.com",
AuthURL: "https://example.com/auth",
TokenURL: "https://example.com/token",
JWKSURL: "https://example.com/jwks",
}
err := cache.Set("test-provider", metadata, time.Hour)
if err != nil {
t.Errorf("Unexpected error setting metadata: %v", err)
}
// Test that the cache was created (basic verification)
// Note: We can't easily test Get without more complex setup
}
// TestNewDoublyLinkedList tests doubly linked list creation
func TestNewDoublyLinkedList(t *testing.T) {
list := NewDoublyLinkedList()
if list == nil {
t.Fatal("Expected list to be created, got nil")
}
// Test it's a proper list structure
if list.Len() != 0 {
t.Error("Expected empty list initially")
}
}
// TestDoublyLinkedList_PopFront tests front element removal
func TestDoublyLinkedList_PopFront(t *testing.T) {
list := NewDoublyLinkedList()
// Test popping from empty list
element := list.PopFront()
if element != nil {
t.Error("Expected nil when popping from empty list")
}
// Add an element and test popping
added := list.PushBack("test-value")
if added == nil {
t.Fatal("Expected element to be added")
}
popped := list.PopFront()
if popped == nil {
t.Error("Expected element to be popped")
}
if list.Len() != 0 {
t.Error("Expected list to be empty after popping")
}
}
// Benchmark tests for performance
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
-314
View File
@@ -1,314 +0,0 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// Helper function to ensure we have a working cache manager for tests
func getTestCacheManager(t *testing.T) *CacheManager {
cm := GetGlobalCacheManager(&sync.WaitGroup{})
if cm == nil {
t.Fatal("Failed to get cache manager")
}
if cm.manager == nil {
t.Fatal("Cache manager has nil internal manager")
}
return cm
}
// TestCacheManager_Close tests cache manager close functionality
func TestCacheManager_Close(t *testing.T) {
// Get a fresh cache manager
wg := &sync.WaitGroup{}
cm := GetGlobalCacheManager(wg)
if cm == nil {
t.Fatal("Expected cache manager to be created")
}
// Test closing the cache manager
err := cm.Close()
if err != nil {
t.Errorf("Unexpected error closing cache manager: %v", err)
}
}
// TestCleanupGlobalCacheManager tests global cleanup
func TestCleanupGlobalCacheManager(t *testing.T) {
// Test cleanup when no instance exists (should not error)
originalInstance := globalCacheManagerInstance
globalCacheManagerInstance = nil
err := CleanupGlobalCacheManager()
if err != nil {
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
}
// Restore original instance
globalCacheManagerInstance = originalInstance
}
// TestCacheInterfaceWrapper_Delete tests delete functionality
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add an item
cache.Set("test-key", "test-value", time.Hour)
// Verify it exists
value, found := cache.Get("test-key")
if !found {
t.Fatal("Expected key to be found after setting")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
// Delete it
cache.Delete("test-key")
// Verify it's gone
_, found = cache.Get("test-key")
if found {
t.Error("Expected key to be deleted")
}
}
// TestCacheInterfaceWrapper_Size tests size functionality
func TestCacheInterfaceWrapper_Size(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Clear cache first
cache.Clear()
// Check initial size
initialSize := cache.Size()
if initialSize != 0 {
t.Errorf("Expected initial size 0, got %d", initialSize)
}
// Add some items
cache.Set("key1", "value1", time.Hour)
cache.Set("key2", "value2", time.Hour)
// Check size increased
newSize := cache.Size()
if newSize != 2 {
t.Errorf("Expected size 2, got %d", newSize)
}
}
// TestCacheInterfaceWrapper_Clear tests clear functionality
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add some items
cache.Set("key1", "value1", time.Hour)
cache.Set("key2", "value2", time.Hour)
// Verify items exist
size := cache.Size()
if size != 2 {
t.Errorf("Expected 2 items before clear, got %d", size)
}
// Clear all
cache.Clear()
// Verify cache is empty
size = cache.Size()
if size != 0 {
t.Errorf("Expected 0 items after clear, got %d", size)
}
// Verify specific items are gone
_, found := cache.Get("key1")
if found {
t.Error("Expected key1 to be cleared")
}
_, found = cache.Get("key2")
if found {
t.Error("Expected key2 to be cleared")
}
}
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
func TestCacheInterfaceWrapper_Close(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Test close - should not panic
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
wrapper.Close() // Should not panic
// Test close with nil cache
nilWrapper := &CacheInterfaceWrapper{cache: nil}
nilWrapper.Close() // Should not panic
}
// TestCacheInterfaceWrapper_GetStats tests stats functionality
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Get stats
stats := wrapper.GetStats()
if stats == nil {
t.Error("Expected non-nil stats")
}
// Stats should be accessible (len() never returns negative values)
// Just verify it's accessible by checking it's not nil (already done above)
}
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add an item that will expire quickly
cache.Set("expire-key", "expire-value", time.Millisecond)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
// Trigger cleanup
cache.Cleanup()
// Item should be cleaned up
_, found := cache.Get("expire-key")
if found {
t.Error("Expected expired key to be cleaned up")
}
}
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Test setting max size (should not panic)
cache.SetMaxSize(1000)
// We can't easily verify the size was set without exposing internals,
// but we can ensure the method doesn't panic
}
// TestGetSharedCaches tests getting shared cache instances
func TestGetSharedCaches(t *testing.T) {
cm := getTestCacheManager(t)
// Test getting shared token blacklist
blacklist := cm.GetSharedTokenBlacklist()
if blacklist == nil {
t.Error("Expected non-nil token blacklist")
}
// Test getting shared token cache
tokenCache := cm.GetSharedTokenCache()
if tokenCache == nil {
t.Error("Expected non-nil token cache")
}
// Test getting shared metadata cache
metadataCache := cm.GetSharedMetadataCache()
if metadataCache == nil {
t.Error("Expected non-nil metadata cache")
}
// Test getting shared JWK cache
jwkCache := cm.GetSharedJWKCache()
if jwkCache == nil {
t.Error("Expected non-nil JWK cache")
}
}
// TestConcurrentCacheAccess tests thread safety
func TestConcurrentCacheAccess(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
var wg sync.WaitGroup
goroutines := 10
iterations := 10
// Concurrent operations
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("key-%d-%d", id, j)
value := fmt.Sprintf("value-%d-%d", id, j)
cache.Set(key, value, time.Hour)
retrieved, found := cache.Get(key)
if found && retrieved != value {
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
}
cache.Delete(key)
}
}(i)
}
wg.Wait()
}
// Benchmark tests for performance
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Pre-populate cache
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
File diff suppressed because it is too large Load Diff
-319
View File
@@ -1,319 +0,0 @@
// Package circuit_breaker provides circuit breaker implementation for resilience
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// CircuitBreakerState represents the current state of a circuit breaker.
// The circuit breaker pattern prevents cascading failures by monitoring
// error rates and temporarily blocking requests to failing services.
type CircuitBreakerState int
// Circuit breaker states following the standard pattern:
// Closed: Normal operation, requests flow through
// Open: Circuit is tripped, requests are blocked
// HalfOpen: Testing state, limited requests allowed to test recovery
const (
// CircuitBreakerClosed allows all requests through (normal operation)
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen blocks all requests (service is failing)
CircuitBreakerOpen
// CircuitBreakerHalfOpen allows limited requests to test service recovery
CircuitBreakerHalfOpen
)
// String returns a string representation of the circuit breaker state
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Logger interface for dependency injection
type Logger interface {
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// BaseRecoveryMechanism interface for common functionality
type BaseRecoveryMechanism interface {
RecordRequest()
RecordSuccess()
RecordFailure()
GetBaseMetrics() map[string]interface{}
LogInfo(format string, args ...interface{})
LogError(format string, args ...interface{})
LogDebug(format string, args ...interface{})
}
// CircuitBreaker implements the circuit breaker pattern for external service calls.
// It monitors failure rates and automatically opens the circuit when failures
// exceed the threshold, preventing further requests until the service recovers.
type CircuitBreaker struct {
// baseRecovery provides common functionality
baseRecovery BaseRecoveryMechanism
// maxFailures is the threshold for opening the circuit
maxFailures int
// timeout is how long to wait before allowing requests in half-open state
timeout time.Duration
// resetTimeout is how long to wait before transitioning from open to half-open
resetTimeout time.Duration
// state tracks the current circuit breaker state
state CircuitBreakerState
// failures counts consecutive failures
failures int64
// lastFailureTime records when the last failure occurred
lastFailureTime time.Time
// mutex protects shared state
mutex sync.RWMutex
// logger for debugging and monitoring
logger Logger
}
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
// These settings control when the circuit opens and how it recovers.
type CircuitBreakerConfig struct {
// MaxFailures is the number of failures before opening the circuit
MaxFailures int `json:"max_failures"`
// Timeout is how long to wait before trying to recover (open -> half-open)
Timeout time.Duration `json:"timeout"`
// ResetTimeout is how long to wait before fully closing the circuit
ResetTimeout time.Duration `json:"reset_timeout"`
}
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
// Configured for typical web service scenarios with moderate tolerance for failures.
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 60 * time.Second,
ResetTimeout: 30 * time.Second,
}
}
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
// The circuit breaker starts in the closed state, allowing all requests through.
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
return &CircuitBreaker{
baseRecovery: baseRecovery,
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
logger: logger,
}
}
// ExecuteWithContext executes a function through the circuit breaker with context.
// It checks if requests are allowed, executes the function, and updates the circuit state
// based on the result. Implements the ErrorRecoveryMechanism interface.
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
if cb.baseRecovery != nil {
cb.baseRecovery.RecordRequest()
}
if !cb.allowRequest() {
return fmt.Errorf("circuit breaker is open")
}
err := fn()
if err != nil {
cb.recordFailure()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordFailure()
}
return err
}
cb.recordSuccess()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordSuccess()
}
return nil
}
// Execute executes a function through the circuit breaker without context.
// This is provided for backward compatibility with existing code.
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest determines whether to allow a request based on the circuit state.
// Handles state transitions from open to half-open based on timeout.
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
defer cb.mutex.Unlock()
now := time.Now()
switch cb.state {
case CircuitBreakerClosed:
return true
case CircuitBreakerOpen:
if now.Sub(cb.lastFailureTime) > cb.timeout {
cb.state = CircuitBreakerHalfOpen
if cb.logger != nil {
cb.logger.Infof("Circuit breaker transitioning to half-open state")
}
return true
}
return false
case CircuitBreakerHalfOpen:
return true
default:
return false
}
}
// recordFailure records a failure and potentially opens the circuit.
// Updates failure count and triggers state transitions when thresholds are exceeded.
func (cb *CircuitBreaker) recordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.failures++
cb.lastFailureTime = time.Now()
switch cb.state {
case CircuitBreakerClosed:
if cb.failures >= int64(cb.maxFailures) {
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
}
}
case CircuitBreakerHalfOpen:
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
}
}
}
// recordSuccess records a successful request and potentially closes the circuit.
// Resets failure count and transitions from half-open to closed state on success.
func (cb *CircuitBreaker) recordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
switch cb.state {
case CircuitBreakerHalfOpen:
cb.failures = 0
cb.state = CircuitBreakerClosed
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
}
case CircuitBreakerClosed:
cb.failures = 0
}
}
// GetState returns the current state of the circuit breaker.
// Thread-safe method for monitoring circuit breaker status.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// Reset resets the circuit breaker to its initial closed state.
// Clears failure count and state, effectively recovering from any open state.
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = CircuitBreakerClosed
atomic.StoreInt64(&cb.failures, 0)
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
}
}
// IsAvailable returns whether the circuit breaker is currently allowing requests.
// This provides a quick way to check if the service is available.
func (cb *CircuitBreaker) IsAvailable() bool {
return cb.allowRequest()
}
// GetMetrics returns comprehensive metrics about the circuit breaker.
// Includes state information, failure counts, configuration, and base metrics.
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
cb.mutex.RLock()
state := cb.state
failures := cb.failures
lastFailureTime := cb.lastFailureTime
cb.mutex.RUnlock()
var metrics map[string]interface{}
if cb.baseRecovery != nil {
metrics = cb.baseRecovery.GetBaseMetrics()
} else {
metrics = make(map[string]interface{})
}
metrics["state"] = state.String()
metrics["current_failures"] = failures
metrics["max_failures"] = cb.maxFailures
metrics["timeout"] = cb.timeout.String()
metrics["reset_timeout"] = cb.resetTimeout.String()
if !lastFailureTime.IsZero() {
metrics["last_failure_time"] = lastFailureTime
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
}
return metrics
}
// GetFailureCount returns the current failure count
func (cb *CircuitBreaker) GetFailureCount() int64 {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.failures
}
// GetLastFailureTime returns the time of the last failure
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.lastFailureTime
}
// IsOpen returns true if the circuit breaker is in open state
func (cb *CircuitBreaker) IsOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerOpen
}
// IsClosed returns true if the circuit breaker is in closed state
func (cb *CircuitBreaker) IsClosed() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerClosed
}
// IsHalfOpen returns true if the circuit breaker is in half-open state
func (cb *CircuitBreaker) IsHalfOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerHalfOpen
}
-981
View File
@@ -1,981 +0,0 @@
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// Mock implementations for testing
type mockLogger struct {
infoLogs []string
errorLogs []string
debugLogs []string
mu sync.RWMutex
}
func (m *mockLogger) Infof(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Errorf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Debugf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
//lint:ignore U1000 May be needed for future error log verification tests
func (m *mockLogger) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
//lint:ignore U1000 May be needed for future test isolation
func (m *mockLogger) reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = nil
m.errorLogs = nil
m.debugLogs = nil
}
type mockBaseRecoveryMechanism struct {
requestCount int64
successCount int64
failureCount int64
infoLogs []string
errorLogs []string
debugLogs []string
baseMetrics map[string]interface{}
mu sync.RWMutex
}
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
return &mockBaseRecoveryMechanism{
baseMetrics: make(map[string]interface{}),
}
}
func (m *mockBaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&m.requestCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&m.successCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&m.failureCount, 1)
}
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]interface{})
for k, v := range m.baseMetrics {
result[k] = v
}
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
result["total_successes"] = atomic.LoadInt64(&m.successCount)
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
return result
}
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
return atomic.LoadInt64(&m.requestCount)
}
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
return atomic.LoadInt64(&m.successCount)
}
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
return atomic.LoadInt64(&m.failureCount)
}
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
func TestCircuitBreakerState_String(t *testing.T) {
tests := []struct {
state CircuitBreakerState
expected string
}{
{CircuitBreakerClosed, "closed"},
{CircuitBreakerOpen, "open"},
{CircuitBreakerHalfOpen, "half-open"},
{CircuitBreakerState(999), "unknown"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := tt.state.String()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
}
}
func TestNewCircuitBreaker(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
if cb == nil {
t.Fatal("NewCircuitBreaker returned nil")
}
if cb.maxFailures != 3 {
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
}
if cb.timeout != 30*time.Second {
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
}
if cb.resetTimeout != 15*time.Second {
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
}
if cb.state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
}
if cb.logger != logger {
t.Error("Expected logger to be set")
}
if cb.baseRecovery != baseRecovery {
t.Error("Expected baseRecovery to be set")
}
}
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getSuccessCount() != 1 {
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
}
}
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getFailureCount() != 1 {
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
}
}
func TestCircuitBreaker_Execute(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.Execute(testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
}
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
// First failure
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on first failure, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
}
// Second failure - should open circuit
err = cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on second failure, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
}
// Third attempt - should be blocked
callCount := 0
blockedFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(ctx, blockedFunc)
if err == nil {
t.Error("Expected error when circuit is open")
}
if callCount != 0 {
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
}
}
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond, // Very short for testing
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next request should transition to half-open
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error in half-open state, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
}
}
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout to allow half-open transition
time.Sleep(15 * time.Millisecond)
// First call should transition to half-open, but we'll force it by checking allowRequest
if !cb.allowRequest() {
t.Error("Expected allowRequest to return true after timeout")
}
if cb.GetState() != CircuitBreakerHalfOpen {
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
}
// Failure in half-open should return to open
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
}
}
func TestCircuitBreaker_Reset(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Reset circuit
cb.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
}
if cb.GetFailureCount() != 0 {
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
}
// Should allow requests again
callCount := 0
err := cb.ExecuteWithContext(context.Background(), func() error {
callCount++
return nil
})
if err != nil {
t.Errorf("Expected no error after reset, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
}
}
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially available
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should not be available when open
if cb.IsAvailable() {
t.Error("Expected circuit breaker to be unavailable when open")
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Should be available again after timeout (half-open)
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available after timeout")
}
}
func TestCircuitBreaker_StateCheckers(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially closed
if !cb.IsClosed() {
t.Error("Expected circuit breaker to be closed initially")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open initially")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should be open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when open")
}
if !cb.IsOpen() {
t.Error("Expected circuit breaker to be open")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open when open")
}
// Wait for timeout and trigger half-open
time.Sleep(15 * time.Millisecond)
cb.allowRequest() // This will transition to half-open
// Should be half-open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when half-open")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open when half-open")
}
if !cb.IsHalfOpen() {
t.Error("Expected circuit breaker to be half-open")
}
}
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Record some activity
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
metrics := cb.GetMetrics()
// Check circuit breaker specific metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["current_failures"] != int64(1) {
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
if metrics["timeout"] != "30s" {
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
}
if metrics["reset_timeout"] != "15s" {
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
}
// Check base metrics are included
if metrics["total_requests"] != int64(1) {
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
}
if metrics["custom_metric"] != "custom_value" {
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
}
// Check failure time metrics
if _, exists := metrics["last_failure_time"]; !exists {
t.Error("Expected last_failure_time to exist")
}
if _, exists := metrics["time_since_last_failure"]; !exists {
t.Error("Expected time_since_last_failure to exist")
}
}
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
metrics := cb.GetMetrics()
// Should still have circuit breaker metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
// Should not have base metrics
if _, exists := metrics["total_requests"]; exists {
t.Error("Expected total_requests not to exist without base recovery")
}
}
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially should be zero
if !cb.GetLastFailureTime().IsZero() {
t.Error("Expected last failure time to be zero initially")
}
// Record a failure
before := time.Now()
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
after := time.Now()
lastFailure := cb.GetLastFailureTime()
if lastFailure.IsZero() {
t.Error("Expected last failure time to be set after failure")
}
if lastFailure.Before(before) || lastFailure.After(after) {
t.Errorf("Expected last failure time to be between %v and %v, got %v",
before, after, lastFailure)
}
}
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
// Should work fine without base recovery
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
}
}
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 10, // Higher threshold for concurrent test
Timeout: 100 * time.Millisecond,
ResetTimeout: 50 * time.Millisecond,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
const numGoroutines = 10
const numOperations = 50
var wg sync.WaitGroup
successCount := int64(0)
errorCount := int64(0)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
err := cb.ExecuteWithContext(context.Background(), func() error {
// Simulate some failures
if j%10 == 9 { // Every 10th operation fails
return fmt.Errorf("simulated error")
}
return nil
})
if err != nil {
atomic.AddInt64(&errorCount, 1)
} else {
atomic.AddInt64(&successCount, 1)
}
// Intermittently check state and metrics
if j%5 == 0 {
cb.GetState()
cb.GetMetrics()
cb.IsAvailable()
}
}
}(i)
}
wg.Wait()
// Verify we got both successes and errors
finalSuccessCount := atomic.LoadInt64(&successCount)
finalErrorCount := atomic.LoadInt64(&errorCount)
if finalSuccessCount == 0 {
t.Error("Expected some successful operations")
}
if finalErrorCount == 0 {
t.Error("Expected some failed operations")
}
totalOperations := finalSuccessCount + finalErrorCount
expectedMax := int64(numGoroutines * numOperations)
if totalOperations > expectedMax {
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
}
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
finalSuccessCount, finalErrorCount, cb.GetState())
}
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Check that error was logged when circuit opened
errorLogs := baseRecovery.getErrorLogs()
if len(errorLogs) == 0 {
t.Error("Expected error log when circuit breaker opened")
} else {
if !contains(errorLogs, "Circuit breaker opened after") {
t.Errorf("Expected circuit opening log, got %v", errorLogs)
}
}
// Wait and trigger half-open
time.Sleep(15 * time.Millisecond)
// Successful request should close circuit and log
cb.ExecuteWithContext(context.Background(), func() error {
return nil
})
// Check that success was logged when circuit closed
infoLogs := baseRecovery.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log when circuit breaker closed")
} else {
if !contains(infoLogs, "Circuit breaker closed after successful request") {
t.Errorf("Expected circuit closing log, got %v", infoLogs)
}
}
// Reset should also be logged
cb.Reset()
infoLogs = baseRecovery.getInfoLogs()
if !contains(infoLogs, "Circuit breaker has been reset") {
t.Errorf("Expected reset log, got %v", infoLogs)
}
}
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Wait for timeout and check half-open transition logging
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next allowRequest call should log transition to half-open
cb.allowRequest()
infoLogs := logger.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log for half-open transition")
} else {
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
t.Errorf("Expected half-open transition log, got %v", infoLogs)
}
}
}
// Helper function to check if a slice contains a string with substring
func contains(slice []string, substr string) bool {
for _, s := range slice {
if len(s) >= len(substr) && s[:len(substr)] == substr {
return true
}
}
return false
}
// Benchmark tests
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testFunc := func() error {
return nil
}
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.ExecuteWithContext(ctx, testFunc)
}
})
}
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
config := CircuitBreakerConfig{
MaxFailures: 1000, // High threshold to avoid opening during benchmark
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.ExecuteWithContext(ctx, testFunc)
}
}
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.GetState()
}
})
}
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Add some activity
for i := 0; i < 100; i++ {
if i%2 == 0 {
cb.ExecuteWithContext(context.Background(), func() error { return nil })
} else {
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.GetMetrics()
}
}
-258
View File
@@ -1,258 +0,0 @@
// Package config provides backward compatibility for legacy configuration
package config
import (
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/compat"
"github.com/lukaszraczylo/traefikoidc/internal/features"
)
// LegacyAdapter provides backward compatibility for old Config struct
type LegacyAdapter struct {
unified *UnifiedConfig
adapter *compat.ConfigAdapter
}
// NewLegacyAdapter creates a new legacy adapter from unified config
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
adapter := compat.NewConfigAdapter(unified)
// Register getters for commonly used fields
adapter.RegisterGetter("ProviderURL", func() interface{} {
return unified.Provider.IssuerURL
})
adapter.RegisterGetter("ClientID", func() interface{} {
return unified.Provider.ClientID
})
adapter.RegisterGetter("ClientSecret", func() interface{} {
return unified.Provider.ClientSecret
})
adapter.RegisterGetter("CallbackURL", func() interface{} {
return unified.Provider.RedirectURL
})
adapter.RegisterGetter("LogoutURL", func() interface{} {
return unified.Provider.LogoutURL
})
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
return unified.Provider.PostLogoutRedirectURI
})
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
return unified.Session.EncryptionKey
})
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
return unified.Security.ForceHTTPS
})
adapter.RegisterGetter("LogLevel", func() interface{} {
return unified.Logging.Level
})
adapter.RegisterGetter("Scopes", func() interface{} {
return unified.Provider.Scopes
})
adapter.RegisterGetter("OverrideScopes", func() interface{} {
return unified.Provider.OverrideScopes
})
adapter.RegisterGetter("AllowedUsers", func() interface{} {
return unified.Security.AllowedUsers
})
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
return unified.Security.AllowedUserDomains
})
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
return unified.Security.AllowedRolesAndGroups
})
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
return unified.Security.ExcludedURLs
})
adapter.RegisterGetter("EnablePKCE", func() interface{} {
return unified.Security.EnablePKCE
})
adapter.RegisterGetter("RateLimit", func() interface{} {
return unified.RateLimit.RequestsPerSecond
})
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
return int(unified.Token.RefreshGracePeriod.Seconds())
})
adapter.RegisterGetter("CookieDomain", func() interface{} {
return unified.Session.Domain
})
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
return unified.Security.Headers
})
return &LegacyAdapter{
unified: unified,
adapter: adapter,
}
}
// ToOldConfig converts unified config to old Config struct format
func (la *LegacyAdapter) ToOldConfig() *Config {
// Use feature flags to determine behavior
if !features.IsUnifiedConfigEnabled() {
// Return existing Config if unified config not enabled
return CreateConfig()
}
cfg := &Config{
ProviderURL: la.unified.Provider.IssuerURL,
ClientID: la.unified.Provider.ClientID,
ClientSecret: la.unified.Provider.ClientSecret,
CallbackURL: la.unified.Provider.RedirectURL,
LogoutURL: la.unified.Provider.LogoutURL,
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
SessionEncryptionKey: la.unified.Session.EncryptionKey,
ForceHTTPS: la.unified.Security.ForceHTTPS,
LogLevel: la.unified.Logging.Level,
Scopes: la.unified.Provider.Scopes,
OverrideScopes: la.unified.Provider.OverrideScopes,
AllowedUsers: la.unified.Security.AllowedUsers,
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
ExcludedURLs: la.unified.Security.ExcludedURLs,
EnablePKCE: la.unified.Security.EnablePKCE,
RateLimit: la.unified.RateLimit.RequestsPerSecond,
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
Headers: la.convertHeaders(),
CookieDomain: la.unified.Session.Domain,
SecurityHeaders: la.unified.Security.Headers,
}
return cfg
}
// convertHeaders converts unified header config to old format
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
headers := make([]HeaderConfig, 0)
for name, value := range la.unified.Middleware.CustomHeaders {
headers = append(headers, HeaderConfig{
Name: name,
Value: value,
})
}
return headers
}
// FromOldConfig creates unified config from old Config struct
func FromOldConfig(old *Config) *UnifiedConfig {
unified := NewUnifiedConfig()
// Map provider settings
unified.Provider.IssuerURL = old.ProviderURL
unified.Provider.ClientID = old.ClientID
unified.Provider.ClientSecret = old.ClientSecret
unified.Provider.RedirectURL = old.CallbackURL
unified.Provider.LogoutURL = old.LogoutURL
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
unified.Provider.Scopes = old.Scopes
unified.Provider.OverrideScopes = old.OverrideScopes
// Map session settings
unified.Session.EncryptionKey = old.SessionEncryptionKey
unified.Session.Domain = old.CookieDomain
// Map security settings
unified.Security.ForceHTTPS = old.ForceHTTPS
unified.Security.EnablePKCE = old.EnablePKCE
unified.Security.AllowedUsers = old.AllowedUsers
unified.Security.AllowedUserDomains = old.AllowedUserDomains
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
unified.Security.ExcludedURLs = old.ExcludedURLs
unified.Security.Headers = old.SecurityHeaders
// Map rate limiting
unified.RateLimit.RequestsPerSecond = old.RateLimit
unified.RateLimit.Enabled = old.RateLimit > 0
// Map token settings
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
// Map logging
unified.Logging.Level = old.LogLevel
// Map custom headers
if len(old.Headers) > 0 {
unified.Middleware.CustomHeaders = make(map[string]string)
for _, header := range old.Headers {
unified.Middleware.CustomHeaders[header.Name] = header.Value
}
}
// Store original config in legacy field for reference
unified.Legacy["original"] = old
return unified
}
// timeSecondsToDuration converts seconds to time.Duration
func timeSecondsToDuration(seconds int) time.Duration {
return time.Duration(seconds) * time.Second
}
// GetConfigInterface returns appropriate config based on feature flag
func GetConfigInterface() interface{} {
if features.IsUnifiedConfigEnabled() {
return NewUnifiedConfig()
}
return CreateConfig()
}
// ValidateConfig validates config based on feature flag
func ValidateConfig(cfg interface{}) error {
if features.IsUnifiedConfigEnabled() {
if unified, ok := cfg.(*UnifiedConfig); ok {
return unified.Validate()
}
}
// Fall back to old validation if available
if old, ok := cfg.(*Config); ok {
return old.Validate()
}
return nil
}
// Add Validate method to old Config for compatibility
func (c *Config) Validate() error {
var errors ValidationErrors
// Basic validation for old config
if c.ProviderURL == "" {
errors = append(errors, ValidationError{
Field: "ProviderURL",
Message: "provider URL is required",
})
}
if c.ClientID == "" {
errors = append(errors, ValidationError{
Field: "ClientID",
Message: "client ID is required",
})
}
if c.ClientSecret == "" && !c.EnablePKCE {
errors = append(errors, ValidationError{
Field: "ClientSecret",
Message: "client secret is required (or enable PKCE)",
})
}
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
errors = append(errors, ValidationError{
Field: "SessionEncryptionKey",
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
Value: len(c.SessionEncryptionKey),
})
}
if len(errors) > 0 {
return errors
}
return nil
}
-363
View File
@@ -1,363 +0,0 @@
//go:build !yaegi
package config
import (
"testing"
"github.com/lukaszraczylo/traefikoidc/internal/features"
)
// NewLegacyAdapter Tests
func TestNewLegacyAdapter(t *testing.T) {
unified := NewUnifiedConfig()
unified.Provider.IssuerURL = "https://provider.example.com"
unified.Provider.ClientID = "test-client"
unified.Provider.ClientSecret = "test-secret"
adapter := NewLegacyAdapter(unified)
if adapter == nil {
t.Fatal("Expected NewLegacyAdapter to return non-nil")
}
if adapter.unified != unified {
t.Error("Expected adapter to reference the unified config")
}
if adapter.adapter == nil {
t.Error("Expected internal adapter to be initialized")
}
}
// ToOldConfig Tests
func TestLegacyAdapter_ToOldConfig(t *testing.T) {
unified := NewUnifiedConfig()
unified.Provider.IssuerURL = "https://issuer.example.com"
unified.Provider.ClientID = "client-123"
unified.Provider.ClientSecret = "secret-456"
unified.Provider.RedirectURL = "https://app.example.com/callback"
unified.Provider.LogoutURL = "/logout"
unified.Provider.PostLogoutRedirectURI = "https://app.example.com"
unified.Provider.Scopes = []string{"openid", "profile"}
unified.Provider.OverrideScopes = true
unified.Session.EncryptionKey = "test-encryption-key-32-chars!!"
unified.Session.Domain = "example.com"
unified.Security.ForceHTTPS = true
unified.Security.EnablePKCE = true
unified.Security.AllowedUsers = []string{"user@example.com"}
unified.Security.AllowedUserDomains = []string{"example.com"}
unified.Security.AllowedRolesAndGroups = []string{"admin"}
unified.Security.ExcludedURLs = []string{"/health"}
unified.RateLimit.RequestsPerSecond = 100
unified.Logging.Level = "debug"
unified.Middleware.CustomHeaders = map[string]string{
"X-Header-1": "value1",
"X-Header-2": "value2",
}
adapter := NewLegacyAdapter(unified)
oldConfig := adapter.ToOldConfig()
if oldConfig == nil {
t.Fatal("Expected ToOldConfig to return non-nil")
}
// ToOldConfig behavior depends on feature flag
if !features.IsUnifiedConfigEnabled() {
// When feature is disabled, returns default config
if oldConfig.ProviderURL == "" {
t.Log("Feature flag disabled - ToOldConfig returns default config")
}
return
}
// When feature is enabled, verify all fields were correctly mapped
if oldConfig.ProviderURL != unified.Provider.IssuerURL {
t.Errorf("Expected ProviderURL '%s', got '%s'", unified.Provider.IssuerURL, oldConfig.ProviderURL)
}
if oldConfig.ClientID != unified.Provider.ClientID {
t.Errorf("Expected ClientID '%s', got '%s'", unified.Provider.ClientID, oldConfig.ClientID)
}
if oldConfig.ClientSecret != unified.Provider.ClientSecret {
t.Errorf("Expected ClientSecret '%s', got '%s'", unified.Provider.ClientSecret, oldConfig.ClientSecret)
}
if oldConfig.CallbackURL != unified.Provider.RedirectURL {
t.Error("Expected CallbackURL to match RedirectURL")
}
if oldConfig.LogoutURL != unified.Provider.LogoutURL {
t.Error("Expected LogoutURL to match")
}
if oldConfig.ForceHTTPS != unified.Security.ForceHTTPS {
t.Error("Expected ForceHTTPS to match")
}
if oldConfig.EnablePKCE != unified.Security.EnablePKCE {
t.Error("Expected EnablePKCE to match")
}
if oldConfig.RateLimit != unified.RateLimit.RequestsPerSecond {
t.Errorf("Expected RateLimit %d, got %d", unified.RateLimit.RequestsPerSecond, oldConfig.RateLimit)
}
if len(oldConfig.Headers) != 2 {
t.Errorf("Expected 2 headers, got %d", len(oldConfig.Headers))
}
}
// convertHeaders Tests
func TestLegacyAdapter_convertHeaders(t *testing.T) {
unified := NewUnifiedConfig()
unified.Middleware.CustomHeaders = map[string]string{
"X-Custom-Header-1": "value1",
"X-Custom-Header-2": "value2",
"X-Custom-Header-3": "value3",
}
adapter := NewLegacyAdapter(unified)
headers := adapter.convertHeaders()
if len(headers) != 3 {
t.Errorf("Expected 3 headers, got %d", len(headers))
}
// Check that headers were converted
headerMap := make(map[string]string)
for _, h := range headers {
headerMap[h.Name] = h.Value
}
if headerMap["X-Custom-Header-1"] != "value1" {
t.Error("Expected X-Custom-Header-1 to have value 'value1'")
}
if headerMap["X-Custom-Header-2"] != "value2" {
t.Error("Expected X-Custom-Header-2 to have value 'value2'")
}
}
func TestLegacyAdapter_convertHeaders_Empty(t *testing.T) {
unified := NewUnifiedConfig()
// No custom headers
adapter := NewLegacyAdapter(unified)
headers := adapter.convertHeaders()
if len(headers) != 0 {
t.Errorf("Expected 0 headers, got %d", len(headers))
}
}
// GetConfigInterface Tests
func TestGetConfigInterface(t *testing.T) {
cfg := GetConfigInterface()
if cfg == nil {
t.Fatal("Expected GetConfigInterface to return non-nil")
}
// Should return either UnifiedConfig or Config depending on feature flag
_, isUnified := cfg.(*UnifiedConfig)
_, isOld := cfg.(*Config)
if !isUnified && !isOld {
t.Error("Expected either *UnifiedConfig or *Config")
}
// Verify consistency with feature flag
if features.IsUnifiedConfigEnabled() {
if !isUnified {
t.Error("Expected *UnifiedConfig when unified config is enabled")
}
} else {
if !isOld {
t.Error("Expected *Config when unified config is disabled")
}
}
}
// ValidateConfig Tests
func TestValidateConfig_UnifiedConfig(t *testing.T) {
unified := NewUnifiedConfig()
unified.Provider.IssuerURL = "https://provider.example.com"
unified.Provider.ClientID = "client-id"
unified.Provider.ClientSecret = "client-secret"
unified.Session.EncryptionKey = "encryption-key-32-characters!!"
err := ValidateConfig(unified)
// Should succeed regardless of feature flag since we're passing the right type
if err != nil {
t.Errorf("Expected valid unified config to pass validation, got: %v", err)
}
}
func TestValidateConfig_OldConfig(t *testing.T) {
old := CreateConfig()
old.ProviderURL = "https://provider.example.com"
old.ClientID = "client-id"
old.ClientSecret = "client-secret"
old.SessionEncryptionKey = "encryption-key-32-characters!!"
err := ValidateConfig(old)
if err != nil {
t.Errorf("Expected valid old config to pass validation, got: %v", err)
}
}
func TestValidateConfig_InvalidType(t *testing.T) {
// Pass something that's not a config
err := ValidateConfig("not a config")
if err != nil {
t.Errorf("Expected nil for unknown type, got: %v", err)
}
}
// Config.Validate Tests
func TestConfig_Validate_Valid(t *testing.T) {
cfg := CreateConfig()
cfg.ProviderURL = "https://provider.example.com"
cfg.ClientID = "client-id"
cfg.ClientSecret = "client-secret"
cfg.SessionEncryptionKey = "encryption-key-32-characters!!"
err := cfg.Validate()
if err != nil {
t.Errorf("Expected valid config to pass, got: %v", err)
}
}
func TestConfig_Validate_MissingProviderURL(t *testing.T) {
cfg := CreateConfig()
cfg.ClientID = "client-id"
cfg.ClientSecret = "client-secret"
err := cfg.Validate()
if err == nil {
t.Error("Expected error for missing ProviderURL")
}
// Check if it's a ValidationErrors type
if verrs, ok := err.(ValidationErrors); ok {
found := false
for _, verr := range verrs {
if verr.Field == "ProviderURL" {
found = true
break
}
}
if !found {
t.Error("Expected ProviderURL validation error")
}
}
}
func TestConfig_Validate_MissingClientID(t *testing.T) {
cfg := CreateConfig()
cfg.ProviderURL = "https://provider.example.com"
cfg.ClientSecret = "client-secret"
err := cfg.Validate()
if err == nil {
t.Error("Expected error for missing ClientID")
}
if verrs, ok := err.(ValidationErrors); ok {
found := false
for _, verr := range verrs {
if verr.Field == "ClientID" {
found = true
break
}
}
if !found {
t.Error("Expected ClientID validation error")
}
}
}
func TestConfig_Validate_MissingClientSecret_NoPKCE(t *testing.T) {
cfg := CreateConfig()
cfg.ProviderURL = "https://provider.example.com"
cfg.ClientID = "client-id"
cfg.EnablePKCE = false
err := cfg.Validate()
if err == nil {
t.Error("Expected error for missing ClientSecret without PKCE")
}
if verrs, ok := err.(ValidationErrors); ok {
found := false
for _, verr := range verrs {
if verr.Field == "ClientSecret" {
found = true
break
}
}
if !found {
t.Error("Expected ClientSecret validation error")
}
}
}
func TestConfig_Validate_MissingClientSecret_WithPKCE(t *testing.T) {
cfg := CreateConfig()
cfg.ProviderURL = "https://provider.example.com"
cfg.ClientID = "client-id"
cfg.EnablePKCE = true // PKCE enabled, so ClientSecret not required
err := cfg.Validate()
if err != nil {
t.Errorf("Expected no error with PKCE enabled and no ClientSecret, got: %v", err)
}
}
func TestConfig_Validate_ShortEncryptionKey(t *testing.T) {
cfg := CreateConfig()
cfg.ProviderURL = "https://provider.example.com"
cfg.ClientID = "client-id"
cfg.ClientSecret = "client-secret"
cfg.SessionEncryptionKey = "short" // Too short
err := cfg.Validate()
if err == nil {
t.Error("Expected error for short encryption key")
}
if verrs, ok := err.(ValidationErrors); ok {
found := false
for _, verr := range verrs {
if verr.Field == "SessionEncryptionKey" {
found = true
break
}
}
if !found {
t.Error("Expected SessionEncryptionKey validation error")
}
}
}
func TestConfig_Validate_MultipleErrors(t *testing.T) {
cfg := CreateConfig()
// Missing ProviderURL, ClientID, and ClientSecret
err := cfg.Validate()
if err == nil {
t.Fatal("Expected validation errors")
}
verrs, ok := err.(ValidationErrors)
if !ok {
t.Fatal("Expected ValidationErrors type")
}
if len(verrs) < 2 {
t.Errorf("Expected at least 2 validation errors, got %d", len(verrs))
}
}
File diff suppressed because it is too large Load Diff
-276
View File
@@ -1,276 +0,0 @@
// Package config provides default values and initialization for unified configuration
package config
import (
"time"
)
// NewUnifiedConfig creates a new unified configuration with sensible defaults
func NewUnifiedConfig() *UnifiedConfig {
return &UnifiedConfig{
Provider: DefaultProviderConfig(),
Session: DefaultSessionConfig(),
Token: DefaultTokenConfig(),
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
Security: DefaultSecurityConfig(),
Middleware: DefaultMiddlewareConfig(),
Cache: DefaultCacheConfig(),
RateLimit: DefaultRateLimitConfig(),
Logging: DefaultLoggingConfig(),
Metrics: DefaultMetricsConfig(),
Health: DefaultHealthConfig(),
Transport: DefaultTransportConfig(),
Pool: DefaultPoolConfig(),
Circuit: DefaultCircuitConfig(),
Legacy: make(map[string]interface{}),
}
}
// DefaultProviderConfig returns default provider configuration
func DefaultProviderConfig() ProviderConfig {
return ProviderConfig{
Scopes: []string{"openid", "profile", "email"},
OverrideScopes: false,
CustomClaims: make(map[string]string),
JWKCachePeriod: 24 * time.Hour,
MetadataCacheTTL: 24 * time.Hour,
Discovery: true,
}
}
// DefaultSessionConfig returns default session configuration
func DefaultSessionConfig() SessionConfig {
return SessionConfig{
Name: "oidc_session",
MaxAge: 86400, // 24 hours
ChunkSize: 4000, // Safe size for cookies
MaxChunks: 5,
Path: "/",
Secure: true,
HttpOnly: true,
SameSite: "Lax",
StorageType: "cookie",
CleanupInterval: 1 * time.Hour,
}
}
// DefaultTokenConfig returns default token configuration
func DefaultTokenConfig() TokenConfig {
return TokenConfig{
AccessTokenTTL: 1 * time.Hour,
RefreshTokenTTL: 24 * time.Hour,
RefreshGracePeriod: 60 * time.Second,
ValidationMode: "jwt",
CacheEnabled: true,
CacheTTL: 5 * time.Minute,
CacheNegativeTTL: 30 * time.Second,
ValidateSignature: true,
ValidateExpiry: true,
ValidateAudience: true,
ValidateIssuer: true,
RequiredClaims: []string{"sub", "iat", "exp"},
ClockSkew: 5 * time.Minute,
}
}
// DefaultSecurityConfig returns default security configuration
func DefaultSecurityConfig() SecurityConfig {
return SecurityConfig{
ForceHTTPS: true,
EnablePKCE: true,
AllowedUsers: []string{},
AllowedUserDomains: []string{},
AllowedRolesAndGroups: []string{},
ExcludedURLs: []string{
"/favicon.ico",
"/robots.txt",
"/health",
"/.well-known/",
"/metrics",
"/ping",
"/static/",
"/assets/",
"/js/",
"/css/",
"/images/",
"/fonts/",
},
Headers: createDefaultSecurityConfig(),
CSRFProtection: true,
CSRFTokenName: "csrf_token",
CSRFTokenTTL: 1 * time.Hour,
MaxLoginAttempts: 5,
LockoutDuration: 15 * time.Minute,
RequireMFA: false,
}
}
// DefaultMiddlewareConfig returns default middleware configuration
func DefaultMiddlewareConfig() MiddlewareConfig {
return MiddlewareConfig{
Priority: 1000,
SkipPaths: []string{},
RequirePaths: []string{},
PassthroughMode: false,
MaxRequestSize: 10 * 1024 * 1024, // 10MB
RequestTimeout: 30 * time.Second,
IdleTimeout: 90 * time.Second,
CustomHeaders: make(map[string]string),
RemoveHeaders: []string{},
}
}
// DefaultCacheConfig returns default cache configuration
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
Enabled: true,
Type: "memory",
DefaultTTL: 5 * time.Minute,
MaxEntries: 10000,
MaxEntrySize: 1024 * 1024, // 1MB
EvictionPolicy: "lru",
CleanupInterval: 10 * time.Minute,
Namespace: "traefikoidc",
Compression: false,
Serialization: "json",
}
}
// DefaultRateLimitConfig returns default rate limiting configuration
func DefaultRateLimitConfig() RateLimitConfig {
return RateLimitConfig{
Enabled: false,
RequestsPerSecond: 10,
Burst: 20,
StorageType: "memory",
WindowDuration: 1 * time.Minute,
KeyType: "ip",
CustomKeyFunc: "",
WhitelistIPs: []string{},
WhitelistUsers: []string{},
}
}
// DefaultLoggingConfig returns default logging configuration
func DefaultLoggingConfig() LoggingConfig {
return LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
FilePath: "",
FilterSensitive: true,
MaskFields: []string{
"password",
"secret",
"token",
"key",
"authorization",
"cookie",
},
BufferSize: 8192,
FlushInterval: 5 * time.Second,
AuditEnabled: false,
AuditEvents: []string{
"login",
"logout",
"token_refresh",
"auth_failure",
},
}
}
// DefaultMetricsConfig returns default metrics configuration
func DefaultMetricsConfig() MetricsConfig {
return MetricsConfig{
Enabled: false,
Provider: "prometheus",
Endpoint: "/metrics",
Namespace: "traefikoidc",
Subsystem: "middleware",
CollectInterval: 10 * time.Second,
Histograms: true,
Labels: make(map[string]string),
}
}
// DefaultHealthConfig returns default health check configuration
func DefaultHealthConfig() HealthConfig {
return HealthConfig{
Enabled: true,
Path: "/health",
CheckInterval: 30 * time.Second,
Timeout: 5 * time.Second,
CheckProvider: true,
CheckRedis: true,
CheckCache: true,
MaxLatency: 1 * time.Second,
MinMemory: 100 * 1024 * 1024, // 100MB
}
}
// DefaultTransportConfig returns default HTTP transport configuration
func DefaultTransportConfig() TransportConfig {
return TransportConfig{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 0, // No limit
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
DisableKeepAlives: false,
DisableCompression: false,
TLSInsecureSkipVerify: false,
TLSMinVersion: "TLS1.2",
TLSCipherSuites: []string{},
ProxyURL: "",
NoProxy: []string{},
}
}
// DefaultPoolConfig returns default connection pool configuration
func DefaultPoolConfig() PoolConfig {
return PoolConfig{
Enabled: true,
Size: 10,
MinSize: 2,
MaxSize: 50,
MaxAge: 30 * time.Minute,
IdleTimeout: 5 * time.Minute,
WaitTimeout: 5 * time.Second,
HealthCheckInterval: 30 * time.Second,
MaxRetries: 3,
}
}
// DefaultCircuitConfig returns default circuit breaker configuration
func DefaultCircuitConfig() CircuitConfig {
return CircuitConfig{
Enabled: true,
MaxRequests: 100,
Interval: 10 * time.Second,
Timeout: 60 * time.Second,
ConsecutiveFailures: 5,
FailureRatio: 0.5,
OnOpen: "reject",
OnHalfOpen: "passthrough",
MetricsEnabled: true,
LogStateChanges: true,
}
}
// MergeWithDefaults merges a partial configuration with defaults
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
if partial == nil {
return NewUnifiedConfig()
}
// Ensure Legacy field is initialized
if partial.Legacy == nil {
partial.Legacy = make(map[string]interface{})
}
// TODO: Implement deep merge logic with defaults
// For now, just return the partial config
return partial
}
-396
View File
@@ -1,396 +0,0 @@
// Package config provides configuration loading and merging logic
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"github.com/lukaszraczylo/traefikoidc/internal/features"
"gopkg.in/yaml.v3"
)
// ConfigLoader handles loading configuration from various sources
type ConfigLoader struct {
migrator *ConfigMigrator
envPrefix string
configPaths []string
}
// NewConfigLoader creates a new configuration loader
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
migrator: NewConfigMigrator(),
envPrefix: "TRAEFIKOIDC_",
configPaths: getDefaultConfigPaths(),
}
}
// getDefaultConfigPaths returns default configuration file paths to check
func getDefaultConfigPaths() []string {
return []string{
"traefik-oidc.yaml",
"traefik-oidc.yml",
"traefik-oidc.json",
"config.yaml",
"config.yml",
"config.json",
"/etc/traefik-oidc/config.yaml",
"/etc/traefik-oidc/config.json",
}
}
// Load loads configuration from all available sources
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
// Start with defaults
config := NewUnifiedConfig()
// Try to load from file
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
config = l.mergeConfigs(config, fileConfig)
}
// Load from environment variables
l.LoadFromEnv(config)
// Validate the final configuration
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
return config, nil
}
// LoadFromFile loads configuration from a file
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
// Use provided paths or default paths
searchPaths := paths
if len(searchPaths) == 0 {
searchPaths = l.configPaths
}
// Check for config file in environment variable
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
searchPaths = append([]string{envPath}, searchPaths...)
}
// Try each path
for _, path := range searchPaths {
if _, err := os.Stat(path); err == nil {
return l.loadFile(path)
}
}
// No config file found, not an error (use defaults)
return nil, nil
}
// loadFile loads a specific configuration file
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories (current dir or subdirs)
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
// Read the file with validated path
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
}
// Check if unified config is enabled
if features.IsUnifiedConfigEnabled() {
// Use migrator to handle any version
config, warnings, err := l.migrator.Migrate(data)
if err != nil {
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
}
// Log warnings
for _, warning := range warnings {
// In production, use proper logging
fmt.Printf("Config Warning (%s): %s\n", path, warning)
}
return config, nil
}
// Legacy path: load old config and convert
ext := strings.ToLower(filepath.Ext(path))
var oldConfig Config
switch ext {
case ".json":
if err := json.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
}
case ".yaml", ".yml":
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
}
default:
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
}
return FromOldConfig(&oldConfig), nil
}
// LoadFromEnv loads configuration from environment variables
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
// Provider configuration
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
// Session configuration
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
// Security configuration
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
// Cache configuration
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
// MaxEntrySize is int64, skip for now
// Rate limiting
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
// Logging
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
// Redis configuration (already handled by its own LoadFromEnv)
config.Redis.LoadFromEnv()
// Feature flags
features.GetManager().LoadFromEnv()
}
// Helper methods for environment variable loading
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = value
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = value
return
}
}
}
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
}
}
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
// Try without prefix
if value := os.Getenv(key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
}
}
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = splitAndTrim(value)
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = splitAndTrim(value)
return
}
}
}
func splitAndTrim(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
// mergeConfigs merges two configurations, with source overriding target
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
if source == nil {
return target
}
if target == nil {
return source
}
// Use reflection for deep merge
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
return target
}
// mergeStructs recursively merges two structs
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
for i := 0; i < source.NumField(); i++ {
sourceField := source.Field(i)
targetField := target.Field(i)
// Skip if source field is zero value
if isZeroValue(sourceField) {
continue
}
switch sourceField.Kind() {
case reflect.Struct:
// Recursively merge structs
l.mergeStructs(targetField, sourceField)
case reflect.Slice:
// Replace slice if source has values
if sourceField.Len() > 0 {
targetField.Set(sourceField)
}
case reflect.Map:
// Merge maps
if !sourceField.IsNil() {
if targetField.IsNil() {
targetField.Set(reflect.MakeMap(sourceField.Type()))
}
for _, key := range sourceField.MapKeys() {
targetField.SetMapIndex(key, sourceField.MapIndex(key))
}
}
default:
// Replace value
targetField.Set(sourceField)
}
}
}
// isZeroValue checks if a reflect.Value is a zero value
func isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return v.IsNil()
case reflect.Slice, reflect.Map:
return v.IsNil() || v.Len() == 0
case reflect.Struct:
// Check if all fields are zero
for i := 0; i < v.NumField(); i++ {
if !isZeroValue(v.Field(i)) {
return false
}
}
return true
default:
zero := reflect.Zero(v.Type())
return reflect.DeepEqual(v.Interface(), zero.Interface())
}
}
// SaveToFile saves the configuration to a file
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
ext := strings.ToLower(filepath.Ext(absPath))
var data []byte
switch ext {
case ".json":
data, err = json.MarshalIndent(config, "", " ")
case ".yaml", ".yml":
data, err = yaml.Marshal(config)
default:
return fmt.Errorf("unsupported file extension: %s", ext)
}
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create directory if it doesn't exist with secure permissions
dir := filepath.Dir(absPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write file with secure permissions (owner read/write only)
if err := os.WriteFile(absPath, data, 0600); err != nil {
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
}
return nil
}
-832
View File
@@ -1,832 +0,0 @@
//go:build !yaegi
package config
import (
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
// TestConfigLoader tests the config loader functionality
func TestConfigLoader(t *testing.T) {
loader := NewConfigLoader()
if loader == nil {
t.Fatal("NewConfigLoader should not return nil")
}
if loader.migrator == nil {
t.Error("ConfigLoader should have a migrator")
}
if loader.envPrefix != "TRAEFIKOIDC_" {
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
}
if len(loader.configPaths) == 0 {
t.Error("ConfigLoader should have default config paths")
}
}
// TestLoadFromEnv tests loading configuration from environment variables
func TestLoadFromEnv(t *testing.T) {
// Set up test environment variables
testEnvVars := map[string]string{
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
"TRAEFIKOIDC_REDIS_ENABLED": "true",
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
"TRAEFIKOIDC_CACHE_ENABLED": "true",
"TRAEFIKOIDC_CACHE_TYPE": "redis",
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
}
// Set environment variables
for key, value := range testEnvVars {
os.Setenv(key, value)
defer os.Unsetenv(key)
}
loader := NewConfigLoader()
config := &UnifiedConfig{}
loader.LoadFromEnv(config)
// Verify values were loaded
if config.Provider.IssuerURL != "https://test.example.com" {
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
}
if config.Provider.ClientID != "test-client-id" {
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
}
if config.Provider.ClientSecret != "test-secret" {
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
}
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
}
if !config.Security.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true")
}
if !config.Cache.Enabled {
t.Error("Expected Cache to be enabled")
}
if config.Cache.Type != "redis" {
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
}
if !config.RateLimit.Enabled {
t.Error("Expected RateLimit to be enabled")
}
if config.RateLimit.RequestsPerSecond != 100 {
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
}
}
// TestSaveToFile tests saving configuration to files
func TestSaveToFile(t *testing.T) {
// Create a temporary directory for test files
tmpDir, err := os.MkdirTemp("", "config-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
loader := NewConfigLoader()
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "32-character-encryption-key-12345",
},
}
tests := []struct {
name string
filename string
wantErr bool
}{
{
name: "save as JSON",
filename: "config.json",
wantErr: false,
},
{
name: "save as YAML",
filename: "config.yaml",
wantErr: false,
},
{
name: "save as YML",
filename: "config.yml",
wantErr: false,
},
{
name: "unsupported extension",
filename: "config.txt",
wantErr: true,
},
{
name: "path traversal attempt",
filename: "../../../etc/config.json",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filePath := filepath.Join(tmpDir, tt.filename)
err := loader.SaveToFile(config, filePath)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
// Verify file was created with correct permissions
info, err := os.Stat(filePath)
if err != nil {
t.Errorf("Failed to stat saved file: %v", err)
return
}
// Check file permissions (should be 0600)
mode := info.Mode().Perm()
if mode != 0600 {
t.Errorf("Expected file permissions 0600, got %o", mode)
}
// Verify content can be read back
data, err := os.ReadFile(filePath)
if err != nil {
t.Errorf("Failed to read saved file: %v", err)
return
}
// Verify secrets are redacted
content := string(data)
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
t.Error("Secrets should be redacted in saved file")
}
})
}
}
// TestLoadFile tests loading configuration from files
func TestLoadFile(t *testing.T) {
// Create a temporary directory for test files
tmpDir, err := os.MkdirTemp("", "config-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Test data - using old config format since unified config is not enabled by default
jsonConfig := `{
"providerURL": "https://auth.example.com",
"clientID": "test-client",
"clientSecret": "secret",
"sessionEncryptionKey": "32-character-encryption-key-12345"
}`
yamlConfig := `
providerurl: https://auth.example.com
clientid: test-client
clientsecret: secret
sessionencryptionkey: 32-character-encryption-key-12345
`
tests := []struct {
name string
filename string
content string
wantErr bool
}{
{
name: "load JSON config",
filename: "config.json",
content: jsonConfig,
wantErr: false,
},
{
name: "load YAML config",
filename: "config.yaml",
content: yamlConfig,
wantErr: false,
},
{
name: "path traversal attempt",
filename: "../../../etc/passwd",
content: "",
wantErr: true,
},
{
name: "non-existent file",
filename: "does-not-exist.json",
content: "",
wantErr: true,
},
}
loader := NewConfigLoader()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var filePath string
if tt.content != "" {
filePath = filepath.Join(tmpDir, tt.filename)
err := os.WriteFile(filePath, []byte(tt.content), 0600)
if err != nil {
t.Fatalf("Failed to write test file: %v", err)
return
}
} else {
filePath = tt.filename
}
config, err := loader.loadFile(filePath)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
t.Errorf("Unexpected error: %v", err)
}
return
}
// Verify loaded config
if config == nil {
t.Error("Expected config to be loaded")
return
}
if config.Provider.IssuerURL != "https://auth.example.com" {
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
}
if config.Provider.ClientID != "test-client" {
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
}
})
}
}
// ====================================================================================
// Tests for untested functions (0% coverage)
// ====================================================================================
// TestConfigLoader_Load tests the full Load pipeline
func TestConfigLoader_Load(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "config-load-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create a test config file
configPath := filepath.Join(tmpDir, "traefik-oidc.json")
configData := `{
"providerURL": "https://auth.example.com",
"clientID": "test-client",
"clientSecret": "test-secret",
"sessionEncryptionKey": "32-character-encryption-key-12345"
}`
err = os.WriteFile(configPath, []byte(configData), 0600)
if err != nil {
t.Fatalf("Failed to write test config file: %v", err)
}
// Change to temp directory so loader can find the config
oldDir, _ := os.Getwd()
os.Chdir(tmpDir)
defer os.Chdir(oldDir)
// Set some environment variables to test merging
os.Setenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS", "true")
defer os.Unsetenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS")
loader := NewConfigLoader()
config, err := loader.Load()
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if config == nil {
t.Fatal("Load() returned nil config")
}
// Verify file was loaded
if config.Provider.IssuerURL != "https://auth.example.com" {
t.Errorf("Expected IssuerURL from file, got %s", config.Provider.IssuerURL)
}
// Verify env vars were loaded
if !config.Security.ForceHTTPS {
t.Error("Expected ForceHTTPS from env var to be true")
}
}
// TestConfigLoader_LoadFromFile tests the LoadFromFile function
func TestConfigLoader_LoadFromFile(t *testing.T) {
t.Run("NoConfigFile", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "config-nofile-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
oldDir, _ := os.Getwd()
os.Chdir(tmpDir)
defer os.Chdir(oldDir)
loader := NewConfigLoader()
config, err := loader.LoadFromFile()
// Should not error when no config file found
if err != nil {
t.Errorf("LoadFromFile() should not error when no file found: %v", err)
}
// Should return nil config
if config != nil {
t.Error("LoadFromFile() should return nil config when no file found")
}
})
t.Run("LoadFromEnvPath", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "config-envpath-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create config file
configPath := filepath.Join(tmpDir, "custom-config.json")
configData := `{
"providerURL": "https://custom.example.com",
"clientID": "custom-client"
}`
err = os.WriteFile(configPath, []byte(configData), 0600)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
// Set env variable pointing to config
os.Setenv("TRAEFIKOIDC_CONFIG_FILE", configPath)
defer os.Unsetenv("TRAEFIKOIDC_CONFIG_FILE")
loader := NewConfigLoader()
config, err := loader.LoadFromFile()
if err != nil {
t.Fatalf("LoadFromFile() failed: %v", err)
}
if config == nil {
t.Fatal("LoadFromFile() returned nil config")
}
if config.Provider.IssuerURL != "https://custom.example.com" {
t.Errorf("Expected IssuerURL 'https://custom.example.com', got %s", config.Provider.IssuerURL)
}
})
t.Run("LoadWithProvidedPaths", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "config-provided-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create config file
configPath := filepath.Join(tmpDir, "specific.json")
configData := `{
"providerURL": "https://specific.example.com",
"clientID": "specific-client"
}`
err = os.WriteFile(configPath, []byte(configData), 0600)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
loader := NewConfigLoader()
config, err := loader.LoadFromFile(configPath)
if err != nil {
t.Fatalf("LoadFromFile() with path failed: %v", err)
}
if config == nil {
t.Fatal("LoadFromFile() returned nil config")
}
if config.Provider.IssuerURL != "https://specific.example.com" {
t.Errorf("Expected IssuerURL 'https://specific.example.com', got %s", config.Provider.IssuerURL)
}
})
}
// TestSplitAndTrim tests the splitAndTrim helper function
func TestSplitAndTrim(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "a,b,c",
expected: []string{"a", "b", "c"},
},
{
name: "With spaces",
input: "a, b , c",
expected: []string{"a", "b", "c"},
},
{
name: "Empty strings filtered out",
input: "a,,b, ,c",
expected: []string{"a", "b", "c"},
},
{
name: "Leading and trailing spaces",
input: " a , b , c ",
expected: []string{"a", "b", "c"},
},
{
name: "Single value",
input: "single",
expected: []string{"single"},
},
{
name: "Empty string",
input: "",
expected: []string{},
},
{
name: "Only commas and spaces",
input: " , , , ",
expected: []string{},
},
{
name: "Complex real-world example",
input: "openid, profile, email, groups",
expected: []string{"openid", "profile", "email", "groups"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitAndTrim(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d items, got %d: %v", len(tt.expected), len(result), result)
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("At index %d: expected %q, got %q", i, expected, result[i])
}
}
})
}
}
// TestConfigLoader_MergeConfigs tests the mergeConfigs function
func TestConfigLoader_MergeConfigs(t *testing.T) {
loader := NewConfigLoader()
t.Run("MergeNilSource", func(t *testing.T) {
target := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://target.example.com",
},
}
result := loader.mergeConfigs(target, nil)
if result != target {
t.Error("mergeConfigs should return target when source is nil")
}
})
t.Run("MergeNilTarget", func(t *testing.T) {
source := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://source.example.com",
},
}
result := loader.mergeConfigs(nil, source)
if result != source {
t.Error("mergeConfigs should return source when target is nil")
}
})
t.Run("MergeSimpleFields", func(t *testing.T) {
target := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://target.example.com",
ClientID: "",
},
}
source := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://source.example.com",
ClientID: "source-client",
},
}
result := loader.mergeConfigs(target, source)
if result.Provider.IssuerURL != "https://source.example.com" {
t.Errorf("Expected IssuerURL to be overridden, got %s", result.Provider.IssuerURL)
}
if result.Provider.ClientID != "source-client" {
t.Errorf("Expected ClientID to be set, got %s", result.Provider.ClientID)
}
})
t.Run("MergeSlices", func(t *testing.T) {
target := &UnifiedConfig{
Provider: ProviderConfig{
Scopes: []string{"openid", "profile"},
},
}
source := &UnifiedConfig{
Provider: ProviderConfig{
Scopes: []string{"email", "groups"},
},
}
result := loader.mergeConfigs(target, source)
// Source slice should replace target slice
if len(result.Provider.Scopes) != 2 {
t.Errorf("Expected 2 scopes, got %d", len(result.Provider.Scopes))
}
if result.Provider.Scopes[0] != "email" {
t.Errorf("Expected first scope 'email', got %s", result.Provider.Scopes[0])
}
})
t.Run("MergeMaps", func(t *testing.T) {
target := &UnifiedConfig{
Middleware: MiddlewareConfig{
CustomHeaders: map[string]string{
"X-Target-Header": "target-value",
},
},
}
source := &UnifiedConfig{
Middleware: MiddlewareConfig{
CustomHeaders: map[string]string{
"X-Source-Header": "source-value",
"X-Target-Header": "overridden-value",
},
},
}
result := loader.mergeConfigs(target, source)
if len(result.Middleware.CustomHeaders) != 2 {
t.Errorf("Expected 2 headers, got %d", len(result.Middleware.CustomHeaders))
}
if result.Middleware.CustomHeaders["X-Target-Header"] != "overridden-value" {
t.Errorf("Expected X-Target-Header to be overridden")
}
if result.Middleware.CustomHeaders["X-Source-Header"] != "source-value" {
t.Errorf("Expected X-Source-Header to be added")
}
})
}
// TestConfigLoader_MergeStructs tests the mergeStructs function indirectly
func TestConfigLoader_MergeStructs(t *testing.T) {
loader := NewConfigLoader()
t.Run("NestedStructMerge", func(t *testing.T) {
target := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://target.example.com",
ClientID: "target-client",
},
Session: SessionConfig{
Name: "target-session",
MaxAge: 3600,
},
}
source := &UnifiedConfig{
Provider: ProviderConfig{
ClientID: "source-client",
ClientSecret: "source-secret",
},
Session: SessionConfig{
MaxAge: 7200,
},
}
result := loader.mergeConfigs(target, source)
// Provider.IssuerURL should remain (zero value in source)
if result.Provider.IssuerURL != "https://target.example.com" {
t.Errorf("Expected IssuerURL to remain, got %s", result.Provider.IssuerURL)
}
// Provider.ClientID should be overridden
if result.Provider.ClientID != "source-client" {
t.Errorf("Expected ClientID to be overridden, got %s", result.Provider.ClientID)
}
// Provider.ClientSecret should be added
if result.Provider.ClientSecret != "source-secret" {
t.Errorf("Expected ClientSecret to be added, got %s", result.Provider.ClientSecret)
}
// Session.Name should remain (zero value in source)
if result.Session.Name != "target-session" {
t.Errorf("Expected Session.Name to remain, got %s", result.Session.Name)
}
// Session.MaxAge should be overridden
if result.Session.MaxAge != 7200 {
t.Errorf("Expected Session.MaxAge to be overridden, got %d", result.Session.MaxAge)
}
})
}
// TestIsZeroValue tests the isZeroValue helper function
func TestIsZeroValue(t *testing.T) {
tests := []struct {
name string
value interface{}
expected bool
}{
{
name: "Zero string",
value: "",
expected: true,
},
{
name: "Non-zero string",
value: "hello",
expected: false,
},
{
name: "Zero int",
value: 0,
expected: true,
},
{
name: "Non-zero int",
value: 42,
expected: false,
},
{
name: "Zero bool",
value: false,
expected: true,
},
{
name: "Non-zero bool",
value: true,
expected: false,
},
{
name: "Nil pointer",
value: (*string)(nil),
expected: true,
},
{
name: "Non-nil pointer",
value: stringPtr("test"),
expected: false,
},
{
name: "Nil slice",
value: ([]string)(nil),
expected: true,
},
{
name: "Empty slice",
value: []string{},
expected: true,
},
{
name: "Non-empty slice",
value: []string{"a"},
expected: false,
},
{
name: "Nil map",
value: (map[string]string)(nil),
expected: true,
},
{
name: "Empty map",
value: map[string]string{},
expected: true,
},
{
name: "Non-empty map",
value: map[string]string{"key": "value"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := reflect.ValueOf(tt.value)
result := isZeroValue(v)
if result != tt.expected {
t.Errorf("Expected isZeroValue to be %v, got %v", tt.expected, result)
}
})
}
}
// TestIsZeroValue_Struct tests isZeroValue with struct types
func TestIsZeroValue_Struct(t *testing.T) {
type TestStruct struct {
Field1 string
Field2 int
}
t.Run("Zero struct", func(t *testing.T) {
s := TestStruct{}
v := reflect.ValueOf(s)
result := isZeroValue(v)
if !result {
t.Error("Expected zero struct to return true")
}
})
t.Run("Non-zero struct - Field1 set", func(t *testing.T) {
s := TestStruct{Field1: "test"}
v := reflect.ValueOf(s)
result := isZeroValue(v)
if result {
t.Error("Expected non-zero struct to return false")
}
})
t.Run("Non-zero struct - Field2 set", func(t *testing.T) {
s := TestStruct{Field2: 42}
v := reflect.ValueOf(s)
result := isZeroValue(v)
if result {
t.Error("Expected non-zero struct to return false")
}
})
t.Run("Non-zero struct - Both fields set", func(t *testing.T) {
s := TestStruct{Field1: "test", Field2: 42}
v := reflect.ValueOf(s)
result := isZeroValue(v)
if result {
t.Error("Expected non-zero struct to return false")
}
})
}
// Helper function for pointer tests
func stringPtr(s string) *string {
return &s
}
-169
View File
@@ -1,169 +0,0 @@
// Package config provides unified configuration management for the OIDC middleware
package config
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
// Create an alias to avoid recursion
type Alias UnifiedConfig
// Create a copy with redacted sensitive fields
copy := (Alias)(c)
// Redact provider secrets
if copy.Provider.ClientSecret != "" {
copy.Provider.ClientSecret = REDACTED
}
// Redact session secrets
if copy.Session.Secret != "" {
copy.Session.Secret = REDACTED
}
if copy.Session.EncryptionKey != "" {
copy.Session.EncryptionKey = REDACTED
}
if copy.Session.SigningKey != "" {
copy.Session.SigningKey = REDACTED
}
// Redact Redis passwords
if copy.Redis.Password != "" {
copy.Redis.Password = REDACTED
}
if copy.Redis.SentinelPassword != "" {
copy.Redis.SentinelPassword = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for ProviderConfig to redact sensitive fields
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
type Alias ProviderConfig
copy := (Alias)(p)
if copy.ClientSecret != "" {
copy.ClientSecret = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for SessionConfig to redact sensitive fields
func (s SessionConfig) MarshalJSON() ([]byte, error) {
type Alias SessionConfig
copy := (Alias)(s)
if copy.Secret != "" {
copy.Secret = REDACTED
}
if copy.EncryptionKey != "" {
copy.EncryptionKey = REDACTED
}
if copy.SigningKey != "" {
copy.SigningKey = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for RedisConfig to redact sensitive fields
func (r RedisConfig) MarshalJSON() ([]byte, error) {
type Alias RedisConfig
copy := (Alias)(r)
if copy.Password != "" {
copy.Password = REDACTED
}
if copy.SentinelPassword != "" {
copy.SentinelPassword = REDACTED
}
return json.Marshal(copy)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
// Create an alias to avoid recursion
type Alias UnifiedConfig
// Create a copy with redacted sensitive fields
copy := (Alias)(c)
// Redact provider secrets
if copy.Provider.ClientSecret != "" {
copy.Provider.ClientSecret = REDACTED
}
// Redact session secrets
if copy.Session.Secret != "" {
copy.Session.Secret = REDACTED
}
if copy.Session.EncryptionKey != "" {
copy.Session.EncryptionKey = REDACTED
}
if copy.Session.SigningKey != "" {
copy.Session.SigningKey = REDACTED
}
// Redact Redis passwords
if copy.Redis.Password != "" {
copy.Redis.Password = REDACTED
}
if copy.Redis.SentinelPassword != "" {
copy.Redis.SentinelPassword = REDACTED
}
return copy, nil
}
// MarshalYAML for ProviderConfig to redact sensitive fields
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
type Alias ProviderConfig
copy := (Alias)(p)
if copy.ClientSecret != "" {
copy.ClientSecret = REDACTED
}
return copy, nil
}
// MarshalYAML for SessionConfig to redact sensitive fields
func (s SessionConfig) MarshalYAML() (interface{}, error) {
type Alias SessionConfig
copy := (Alias)(s)
if copy.Secret != "" {
copy.Secret = REDACTED
}
if copy.EncryptionKey != "" {
copy.EncryptionKey = REDACTED
}
if copy.SigningKey != "" {
copy.SigningKey = REDACTED
}
return copy, nil
}
// MarshalYAML for RedisConfig to redact sensitive fields
func (r RedisConfig) MarshalYAML() (interface{}, error) {
type Alias RedisConfig
copy := (Alias)(r)
if copy.Password != "" {
copy.Password = REDACTED
}
if copy.SentinelPassword != "" {
copy.SentinelPassword = REDACTED
}
return copy, nil
}
-407
View File
@@ -1,407 +0,0 @@
// Package config provides configuration migration from old to new format
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/compat"
"github.com/lukaszraczylo/traefikoidc/internal/features"
"gopkg.in/yaml.v3"
)
// ConfigVersion represents the version of a configuration format
type ConfigVersion string
const (
// VersionLegacy represents the original config format
VersionLegacy ConfigVersion = "legacy"
// VersionUnified represents the new unified config format
VersionUnified ConfigVersion = "unified"
// CurrentVersion is the current config version
CurrentVersion ConfigVersion = VersionUnified
)
// ConfigMigrator handles migration between config versions
type ConfigMigrator struct {
compatLayer *compat.CompatibilityLayer
migrations map[ConfigVersion]MigrationFunc
}
// MigrationFunc defines a function that migrates configuration
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
// NewConfigMigrator creates a new configuration migrator
func NewConfigMigrator() *ConfigMigrator {
m := &ConfigMigrator{
compatLayer: compat.GetLayer(),
migrations: make(map[ConfigVersion]MigrationFunc),
}
// Register migration functions
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
return m
}
// DetectVersion detects the version of a configuration
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
var testMap map[string]interface{}
// Try JSON first
if err := json.Unmarshal(data, &testMap); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &testMap); err != nil {
return VersionLegacy // Default to legacy if can't parse
}
}
// Check for unified config markers
if _, hasProvider := testMap["provider"]; hasProvider {
if _, hasSession := testMap["session"]; hasSession {
return VersionUnified
}
}
// Check for legacy config markers
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
return VersionLegacy
}
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
return VersionLegacy
}
return VersionLegacy
}
// Migrate migrates configuration data to the current version
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
warnings := []string{}
// Detect version
version := m.DetectVersion(data)
// If already current version, just unmarshal
if version == CurrentVersion {
var config UnifiedConfig
if err := json.Unmarshal(data, &config); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
}
}
return &config, warnings, nil
}
// Parse to generic map
var configMap map[string]interface{}
if err := json.Unmarshal(data, &configMap); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &configMap); err != nil {
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
}
}
// Apply migration
migrationFunc, exists := m.migrations[version]
if !exists {
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
}
config, err := migrationFunc(configMap)
if err != nil {
return nil, warnings, fmt.Errorf("migration failed: %w", err)
}
// Collect any deprecation warnings
for key := range configMap {
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
warnings = append(warnings, warning)
}
}
return config, warnings, nil
}
// migrateLegacyToUnified migrates legacy config to unified format
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
config := NewUnifiedConfig()
// Use compatibility layer for field mapping
migratedMap, warnings := m.compatLayer.MigrateMap(data)
// Log warnings
for _, warning := range warnings {
// In production, these would be logged
_ = warning
}
// Map provider configuration
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
_ = mapToStruct(provider, &config.Provider)
} else {
// Direct field mapping for legacy format
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
config.Provider.Scopes = scopes
}
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
}
// Map session configuration
if session, ok := getNestedMap(migratedMap, "Session"); ok {
_ = mapToStruct(session, &config.Session)
} else {
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
}
// Map security configuration
if security, ok := getNestedMap(migratedMap, "Security"); ok {
_ = mapToStruct(security, &config.Security)
} else {
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
config.Security.AllowedUsers = users
}
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
config.Security.AllowedUserDomains = domains
}
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
config.Security.AllowedRolesAndGroups = roles
}
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
config.Security.ExcludedURLs = excluded
}
// Handle security headers
if headers := data["securityHeaders"]; headers != nil {
// Security headers might be in old format
_ = mapToStruct(headers, &config.Security.Headers)
}
}
// Map rate limiting
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = rateLimit
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
}
// Map token configuration
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
}
// Map logging
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
if config.Logging.Level == "" {
config.Logging.Level = "info"
}
// Map custom headers
if headers := data["headers"]; headers != nil {
if headerList, ok := headers.([]interface{}); ok {
config.Middleware.CustomHeaders = make(map[string]string)
for _, h := range headerList {
if headerMap, ok := h.(map[string]interface{}); ok {
name := getStringFromInterface(headerMap["name"])
value := getStringFromInterface(headerMap["value"])
if name != "" {
config.Middleware.CustomHeaders[name] = value
}
}
}
}
}
// Store original data for reference
config.Legacy = data
return config, nil
}
// MigrateFile migrates a configuration file
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(filePath)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
}
// Ensure the path is within expected directories
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
}
// Read the file with validated path
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
config, warnings, err := m.Migrate(data)
if err != nil {
return nil, err
}
// Log warnings
for _, warning := range warnings {
fmt.Printf("Migration Warning: %s\n", warning)
}
return config, nil
}
// AutoMigrate automatically migrates config based on feature flags
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
if !features.IsUnifiedConfigEnabled() {
// Feature not enabled, return nil
return nil, nil
}
migrator := NewConfigMigrator()
// Handle different input types
switch v := data.(type) {
case []byte:
config, _, err := migrator.Migrate(v)
return config, err
case string:
config, _, err := migrator.Migrate([]byte(v))
return config, err
case *Config:
// Convert old config to unified
return FromOldConfig(v), nil
case *UnifiedConfig:
// Already unified
return v, nil
case map[string]interface{}:
// Convert map to JSON then migrate
jsonData, err := json.Marshal(v)
if err != nil {
return nil, err
}
config, _, err := migrator.Migrate(jsonData)
return config, err
default:
return nil, fmt.Errorf("unsupported config type: %T", v)
}
}
// Helper functions
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
if val, exists := m[key]; exists {
if mapped, ok := val.(map[string]interface{}); ok {
return mapped, true
}
}
return nil, false
}
func getStringValue(m map[string]interface{}, keys ...string) string {
for _, key := range keys {
if val, exists := m[key]; exists {
return getStringFromInterface(val)
}
}
return ""
}
func getStringFromInterface(val interface{}) string {
if val == nil {
return ""
}
switch v := val.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprintf("%v", v)
}
}
func getBoolValue(m map[string]interface{}, keys ...string) bool {
for _, key := range keys {
if val, exists := m[key]; exists {
if b, ok := val.(bool); ok {
return b
}
// Try string conversion
if s, ok := val.(string); ok {
return strings.ToLower(s) == "true"
}
}
}
return false
}
func getIntValue(m map[string]interface{}, keys ...string) int {
for _, key := range keys {
if val, exists := m[key]; exists {
switch v := val.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case string:
// Try to parse
var i int
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
// If parsing fails, return default
return 0
}
return i
}
}
}
return 0
}
func getArrayValue(m map[string]interface{}, keys ...string) []string {
for _, key := range keys {
if val, exists := m[key]; exists {
if arr, ok := val.([]interface{}); ok {
result := make([]string, 0, len(arr))
for _, item := range arr {
result = append(result, getStringFromInterface(item))
}
return result
}
if strArr, ok := val.([]string); ok {
return strArr
}
}
}
return nil
}
func mapToStruct(m interface{}, target interface{}) error {
// Simple mapping using JSON as intermediate
data, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(data, target)
}
File diff suppressed because it is too large Load Diff
-297
View File
@@ -1,297 +0,0 @@
// Package config provides configuration structures for the Traefik OIDC plugin.
package config
import (
"os"
"strconv"
"strings"
"time"
)
// RedisMode represents the Redis deployment mode
type RedisMode string
const (
// RedisModeStandalone represents a single Redis instance
RedisModeStandalone RedisMode = "standalone"
// RedisModeCluster represents Redis cluster mode
RedisModeCluster RedisMode = "cluster"
// RedisModeSentinel represents Redis sentinel mode
RedisModeSentinel RedisMode = "sentinel"
)
// RedisConfig holds Redis cache backend configuration
type RedisConfig struct {
// Enabled indicates if Redis backend should be used
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
// Mode specifies the Redis deployment mode
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
// === Standalone Configuration ===
// Addr is the Redis server address (host:port)
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
// Password for Redis authentication
Password string `json:"password,omitempty" yaml:"password,omitempty"`
// DB is the database number (0-15)
DB int `json:"db,omitempty" yaml:"db,omitempty"`
// === Cluster Configuration ===
// ClusterAddrs is the list of cluster node addresses
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
// === Sentinel Configuration ===
// MasterName is the name of the master instance
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
// SentinelAddrs is the list of sentinel addresses
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
// SentinelPassword is the password for sentinel authentication
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
// === Connection Pool Settings ===
// PoolSize is the maximum number of socket connections
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
// MinIdleConns is the minimum number of idle connections
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
// MaxRetries is the maximum number of retries before giving up
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
// === Timeouts ===
// DialTimeout is the timeout for establishing new connections
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
// ReadTimeout is the timeout for socket reads
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
// WriteTimeout is the timeout for socket writes
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
// PoolTimeout is the timeout for connection pool
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
// ConnMaxLifetime is the maximum lifetime of a connection
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
// === Key Management ===
// KeyPrefix is the prefix for all Redis keys
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
// === TLS Configuration ===
// TLSEnabled enables TLS for Redis connections
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
// TLSInsecureSkipVerify skips TLS certificate verification
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
// === Resilience Settings ===
// EnableCircuitBreaker enables circuit breaker for Redis operations
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
// CircuitBreakerMaxFailures is the number of failures before opening circuit
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
// CircuitBreakerTimeout is how long the circuit stays open
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
// EnableHealthCheck enables periodic health checks
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
// HealthCheckInterval is how often to check Redis health
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
}
// DefaultRedisConfig returns default Redis configuration
func DefaultRedisConfig() *RedisConfig {
return &RedisConfig{
Enabled: false,
Mode: RedisModeStandalone,
Addr: "localhost:6379",
DB: 0,
PoolSize: 10,
MinIdleConns: 2,
MaxRetries: 3,
DialTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
PoolTimeout: 4 * time.Second,
ConnMaxIdleTime: 5 * time.Minute,
ConnMaxLifetime: 30 * time.Minute,
KeyPrefix: "traefikoidc:",
TLSEnabled: false,
TLSInsecureSkipVerify: false,
EnableCircuitBreaker: true,
CircuitBreakerMaxFailures: 5,
CircuitBreakerTimeout: 30 * time.Second,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
}
}
// LoadFromEnv loads Redis configuration from environment variables
func (c *RedisConfig) LoadFromEnv() {
// Enable Redis if environment variable is set
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
c.Enabled = strings.ToLower(enabled) == "true"
}
// Mode
if mode := os.Getenv("REDIS_MODE"); mode != "" {
c.Mode = RedisMode(strings.ToLower(mode))
}
// Standalone configuration
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
c.Addr = addr
}
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
c.Password = password
}
if db := os.Getenv("REDIS_DB"); db != "" {
if dbNum, err := strconv.Atoi(db); err == nil {
c.DB = dbNum
}
}
// Cluster configuration
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
for i := range c.ClusterAddrs {
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
}
}
// Sentinel configuration
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
c.MasterName = masterName
}
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
for i := range c.SentinelAddrs {
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
}
}
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
c.SentinelPassword = sentinelPassword
}
// Connection pool settings
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
if size, err := strconv.Atoi(poolSize); err == nil {
c.PoolSize = size
}
}
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
if conns, err := strconv.Atoi(minIdleConns); err == nil {
c.MinIdleConns = conns
}
}
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
if retries, err := strconv.Atoi(maxRetries); err == nil {
c.MaxRetries = retries
}
}
// Timeouts
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
c.DialTimeout = timeout
}
}
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
if timeout, err := time.ParseDuration(readTimeout); err == nil {
c.ReadTimeout = timeout
}
}
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
c.WriteTimeout = timeout
}
}
// Key prefix
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
c.KeyPrefix = keyPrefix
}
// TLS settings
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
}
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
}
// Resilience settings
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
}
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
c.CircuitBreakerMaxFailures = failures
}
}
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
c.CircuitBreakerTimeout = timeout
}
}
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
}
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
if interval, err := time.ParseDuration(hcInterval); err == nil {
c.HealthCheckInterval = interval
}
}
}
// Validate checks if the configuration is valid
func (c *RedisConfig) Validate() error {
if !c.Enabled {
return nil
}
switch c.Mode {
case RedisModeStandalone:
if c.Addr == "" {
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
}
case RedisModeCluster:
if len(c.ClusterAddrs) == 0 {
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
}
case RedisModeSentinel:
if c.MasterName == "" {
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
}
if len(c.SentinelAddrs) == 0 {
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
}
default:
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
}
return nil
}
// ConfigError represents a configuration validation error
type ConfigError struct {
Field string
Message string
}
// Error implements the error interface
func (e *ConfigError) Error() string {
return "redis config error: " + e.Field + ": " + e.Message
}
-511
View File
@@ -1,511 +0,0 @@
// Package config provides configuration management for the OIDC middleware
package config
import (
"fmt"
"net/http"
"strconv"
"strings"
)
const (
minEncryptionKeyLength = 16
ConstSessionTimeout = 86400
)
//lint:ignore U1000 May be referenced for default exclusion patterns
var defaultExcludedURLs = map[string]struct{}{
"/favicon.ico": {},
"/robots.txt": {},
"/health": {},
"/.well-known/": {},
"/metrics": {},
"/ping": {},
"/api/": {},
"/static/": {},
"/assets/": {},
"/js/": {},
"/css/": {},
"/images/": {},
"/fonts/": {},
}
// Settings manages configuration and initialization for the OIDC middleware
type Settings struct {
logger Logger
}
// Logger interface for dependency injection
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// Config represents the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerUrl"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackUrl"`
LogoutURL string `json:"logoutUrl"`
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHttps"`
LogLevel string `json:"logLevel"`
Scopes []string `json:"scopes"`
OverrideScopes bool `json:"overrideScopes"`
AllowedUsers []string `json:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedUrls"`
EnablePKCE bool `json:"enablePkce"`
RateLimit int `json:"rateLimit"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
Headers []HeaderConfig `json:"headers"`
HTTPClient *http.Client `json:"-"`
CookieDomain string `json:"cookieDomain"`
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
// Dynamic Client Registration (RFC 7591) configuration
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
}
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
type DynamicClientRegistrationConfig struct {
// Enabled enables automatic client registration with the OIDC provider
Enabled bool `json:"enabled"`
// InitialAccessToken is an optional bearer token for protected registration endpoints
// Some providers require this token to authorize new client registrations
InitialAccessToken string `json:"initialAccessToken,omitempty"`
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
// If empty, uses the registration_endpoint from .well-known/openid-configuration
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
// ClientMetadata contains the client metadata to register
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
// PersistCredentials determines whether to save registered credentials to a file
// This allows reusing the same client_id/client_secret across restarts
PersistCredentials bool `json:"persistCredentials"`
// CredentialsFile is the path to store/load registered client credentials
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
CredentialsFile string `json:"credentialsFile,omitempty"`
}
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
type ClientRegistrationMetadata struct {
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
RedirectURIs []string `json:"redirect_uris"`
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
ResponseTypes []string `json:"response_types,omitempty"`
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
GrantTypes []string `json:"grant_types,omitempty"`
// ApplicationType is either "web" (default) or "native"
ApplicationType string `json:"application_type,omitempty"`
// Contacts is an array of email addresses for responsible parties
Contacts []string `json:"contacts,omitempty"`
// ClientName is a human-readable name for the client
ClientName string `json:"client_name,omitempty"`
// LogoURI is a URL pointing to a logo for the client
LogoURI string `json:"logo_uri,omitempty"`
// ClientURI is a URL of the home page of the client
ClientURI string `json:"client_uri,omitempty"`
// PolicyURI is a URL pointing to the client's privacy policy
PolicyURI string `json:"policy_uri,omitempty"`
// TOSURI is a URL pointing to the client's terms of service
TOSURI string `json:"tos_uri,omitempty"`
// JWKSURI is a URL for the client's JSON Web Key Set
JWKSURI string `json:"jwks_uri,omitempty"`
// SubjectType is "pairwise" or "public" (provider-specific)
SubjectType string `json:"subject_type,omitempty"`
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
// DefaultMaxAge is the default maximum authentication age in seconds
DefaultMaxAge int `json:"default_max_age,omitempty"`
// RequireAuthTime specifies whether auth_time claim is required in ID token
RequireAuthTime bool `json:"require_auth_time,omitempty"`
// DefaultACRValues specifies default ACR values
DefaultACRValues []string `json:"default_acr_values,omitempty"`
// Scope is a space-separated list of scopes (alternative to config.Scopes)
Scope string `json:"scope,omitempty"`
}
// HeaderConfig represents header template configuration
type HeaderConfig struct {
Name string `json:"name"`
Value string `json:"value"`
}
// SecurityHeadersConfig configures security headers for the plugin
type SecurityHeadersConfig struct {
// Enable security headers (default: true)
Enabled bool `json:"enabled"`
// Security profile: "default", "strict", "development", "api", or "custom"
Profile string `json:"profile"`
// Content Security Policy
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
// HSTS settings
StrictTransportSecurity bool `json:"strictTransportSecurity"`
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
FrameOptions string `json:"frameOptions,omitempty"`
// Content type options (default: "nosniff")
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
// XSS protection (default: "1; mode=block")
XSSProtection string `json:"xssProtection,omitempty"`
// Referrer policy
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
// Permissions policy
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
// Cross-origin settings
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
// CORS settings
CORSEnabled bool `json:"corsEnabled"`
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
CORSAllowCredentials bool `json:"corsAllowCredentials"`
CORSMaxAge int `json:"corsMaxAge"` // seconds
// Custom headers (in addition to standard security headers)
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
// Security features
DisableServerHeader bool `json:"disableServerHeader"`
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
}
// NewSettings creates a new Settings instance
func NewSettings(logger Logger) *Settings {
return &Settings{
logger: logger,
}
}
// CreateConfig creates a default configuration
func CreateConfig() *Config {
return &Config{
LogLevel: "INFO",
ForceHTTPS: true,
EnablePKCE: true,
RateLimit: 10,
RefreshGracePeriodSeconds: 60,
Scopes: []string{"openid", "profile", "email"},
Headers: []HeaderConfig{},
SecurityHeaders: createDefaultSecurityConfig(),
}
}
// createDefaultSecurityConfig creates a default security headers configuration
func createDefaultSecurityConfig() *SecurityHeadersConfig {
return &SecurityHeadersConfig{
Enabled: true,
Profile: "default",
// Default security headers
StrictTransportSecurity: true,
StrictTransportSecurityMaxAge: 31536000, // 1 year
StrictTransportSecuritySubdomains: true,
StrictTransportSecurityPreload: true,
FrameOptions: "DENY",
ContentTypeOptions: "nosniff",
XSSProtection: "1; mode=block",
ReferrerPolicy: "strict-origin-when-cross-origin",
// CORS disabled by default
CORSEnabled: false,
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
CORSAllowCredentials: false,
CORSMaxAge: 86400, // 24 hours
// Security features
DisableServerHeader: true,
DisablePoweredByHeader: true,
}
}
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
if c == nil || !c.Enabled {
return nil
}
// Create the internal security config structure
config := map[string]interface{}{
"DevelopmentMode": false,
}
// Apply profile-based defaults
switch strings.ToLower(c.Profile) {
case "strict":
applyStrictProfile(config)
case "development":
applyDevelopmentProfile(config)
case "api":
applyAPIProfile(config)
case "custom":
// No defaults, use only what's explicitly configured
default: // "default"
applyDefaultProfile(config)
}
// Override with explicit configuration
if c.ContentSecurityPolicy != "" {
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
}
// HSTS configuration
if c.StrictTransportSecurity {
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
}
// Frame options
if c.FrameOptions != "" {
config["FrameOptions"] = c.FrameOptions
}
// Content type and XSS protection
if c.ContentTypeOptions != "" {
config["ContentTypeOptions"] = c.ContentTypeOptions
}
if c.XSSProtection != "" {
config["XSSProtection"] = c.XSSProtection
}
// Referrer and permissions policies
if c.ReferrerPolicy != "" {
config["ReferrerPolicy"] = c.ReferrerPolicy
}
if c.PermissionsPolicy != "" {
config["PermissionsPolicy"] = c.PermissionsPolicy
}
// Cross-origin policies
if c.CrossOriginEmbedderPolicy != "" {
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
}
if c.CrossOriginOpenerPolicy != "" {
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
}
if c.CrossOriginResourcePolicy != "" {
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
}
// CORS configuration
config["CORSEnabled"] = c.CORSEnabled
if len(c.CORSAllowedOrigins) > 0 {
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
}
if len(c.CORSAllowedMethods) > 0 {
config["CORSAllowedMethods"] = c.CORSAllowedMethods
}
if len(c.CORSAllowedHeaders) > 0 {
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
}
config["CORSAllowCredentials"] = c.CORSAllowCredentials
if c.CORSMaxAge > 0 {
config["CORSMaxAge"] = c.CORSMaxAge
}
// Custom headers
if len(c.CustomHeaders) > 0 {
config["CustomHeaders"] = c.CustomHeaders
}
// Security features
config["DisableServerHeader"] = c.DisableServerHeader
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
return config
}
// applyDefaultProfile applies default security settings
func applyDefaultProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
config["CrossOriginEmbedderPolicy"] = "require-corp"
config["CrossOriginOpenerPolicy"] = "same-origin"
config["CrossOriginResourcePolicy"] = "same-origin"
}
// applyStrictProfile applies strict security settings
func applyStrictProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
config["CrossOriginEmbedderPolicy"] = "require-corp"
config["CrossOriginOpenerPolicy"] = "same-origin"
config["CrossOriginResourcePolicy"] = "same-site"
}
// applyDevelopmentProfile applies development-friendly settings
func applyDevelopmentProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
config["FrameOptions"] = "SAMEORIGIN"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["CrossOriginOpenerPolicy"] = "unsafe-none"
config["CrossOriginResourcePolicy"] = "cross-origin"
config["DevelopmentMode"] = true
}
// applyAPIProfile applies API-friendly settings
func applyAPIProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["CrossOriginResourcePolicy"] = "cross-origin"
}
// GetSecurityHeadersApplier returns a function that applies security headers
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
return nil
}
// This would need to import the internal security package
// For now, return a basic implementation
return func(rw http.ResponseWriter, req *http.Request) {
headers := rw.Header()
// Apply basic security headers based on configuration
if c.SecurityHeaders.FrameOptions != "" {
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
}
if c.SecurityHeaders.ContentTypeOptions != "" {
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
}
if c.SecurityHeaders.XSSProtection != "" {
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
}
if c.SecurityHeaders.ReferrerPolicy != "" {
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
}
if c.SecurityHeaders.ContentSecurityPolicy != "" {
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
}
// HSTS for HTTPS
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
hstsValue += "; includeSubDomains"
}
if c.SecurityHeaders.StrictTransportSecurityPreload {
hstsValue += "; preload"
}
headers.Set("Strict-Transport-Security", hstsValue)
}
// CORS headers
if c.SecurityHeaders.CORSEnabled {
origin := req.Header.Get("Origin")
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
headers.Set("Access-Control-Allow-Origin", origin)
}
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
}
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
}
if c.SecurityHeaders.CORSAllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if c.SecurityHeaders.CORSMaxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
}
}
// Custom headers
for name, value := range c.SecurityHeaders.CustomHeaders {
headers.Set(name, value)
}
// Remove server headers
if c.SecurityHeaders.DisableServerHeader {
headers.Del("Server")
}
if c.SecurityHeaders.DisablePoweredByHeader {
headers.Del("X-Powered-By")
}
}
}
// isOriginAllowed checks if an origin is in the allowed list
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
return true
}
// Simple wildcard matching for subdomains
if strings.Contains(allowed, "*") {
if strings.HasPrefix(allowed, "https://*.") {
domain := strings.TrimPrefix(allowed, "https://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
return true
}
}
if strings.HasPrefix(allowed, "http://*.") {
domain := strings.TrimPrefix(allowed, "http://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
return true
}
}
}
}
return false
}
-287
View File
@@ -1,287 +0,0 @@
// Package config provides unified configuration management for the OIDC middleware
package config
import (
"time"
)
// UnifiedConfig is the master configuration structure consolidating all config aspects
// This replaces 45 duplicate config structs across the codebase
type UnifiedConfig struct {
// Core Configuration
Provider ProviderConfig `json:"provider" yaml:"provider"`
Session SessionConfig `json:"session" yaml:"session"`
Token TokenConfig `json:"token" yaml:"token"`
Redis RedisConfig `json:"redis" yaml:"redis"`
Security SecurityConfig `json:"security" yaml:"security"`
// Middleware Configuration
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
Cache CacheConfig `json:"cache" yaml:"cache"`
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
// Operational Configuration
Logging LoggingConfig `json:"logging" yaml:"logging"`
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
Health HealthConfig `json:"health" yaml:"health"`
// Advanced Configuration
Transport TransportConfig `json:"transport" yaml:"transport"`
Pool PoolConfig `json:"pool" yaml:"pool"`
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
// Compatibility field for migration
Legacy map[string]interface{} `json:"-" yaml:"-"`
}
// ProviderConfig contains OIDC provider settings
type ProviderConfig struct {
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
ClientID string `json:"clientID" yaml:"clientID"`
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
Scopes []string `json:"scopes" yaml:"scopes"`
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
Discovery bool `json:"discovery" yaml:"discovery"`
// Provider-specific endpoints
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
}
// SessionConfig contains session management settings
type SessionConfig struct {
Name string `json:"name" yaml:"name"`
MaxAge int `json:"maxAge" yaml:"maxAge"`
Secret string `json:"secret" yaml:"secret"`
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
SigningKey string `json:"signingKey" yaml:"signingKey"`
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
// Cookie settings
Domain string `json:"domain" yaml:"domain"`
Path string `json:"path" yaml:"path"`
Secure bool `json:"secure" yaml:"secure"`
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
SameSite string `json:"sameSite" yaml:"sameSite"`
CookiePrefix string `json:"cookiePrefix" yaml:"cookiePrefix"` // Prefix for cookie names (e.g., "_oidc_myapp_")
// Storage settings
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
}
// TokenConfig contains token handling settings
type TokenConfig struct {
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
// Token caching
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
// Token validation
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
}
// SecurityConfig contains security-related settings
type SecurityConfig struct {
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
// CSRF protection
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
// Additional security
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
}
// MiddlewareConfig contains middleware-specific settings
type MiddlewareConfig struct {
Priority int `json:"priority" yaml:"priority"`
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
// Request handling
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
// Response handling
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
}
// CacheConfig contains cache configuration
type CacheConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
// Memory cache settings
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
// Distributed cache settings
Namespace string `json:"namespace" yaml:"namespace"`
Compression bool `json:"compression" yaml:"compression"`
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
}
// RateLimitConfig contains rate limiting configuration
type RateLimitConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
Burst int `json:"burst" yaml:"burst"`
// Rate limit storage
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
// Rate limit keys
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
// Whitelisting
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
}
// LoggingConfig contains logging configuration
type LoggingConfig struct {
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
FilePath string `json:"filePath" yaml:"filePath"`
// Log filtering
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
MaskFields []string `json:"maskFields" yaml:"maskFields"`
// Performance
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
// Audit logging
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
}
// MetricsConfig contains metrics collection configuration
type MetricsConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
Endpoint string `json:"endpoint" yaml:"endpoint"`
Namespace string `json:"namespace" yaml:"namespace"`
Subsystem string `json:"subsystem" yaml:"subsystem"`
// Collection settings
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
Histograms bool `json:"histograms" yaml:"histograms"`
// Custom labels
Labels map[string]string `json:"labels" yaml:"labels"`
}
// HealthConfig contains health check configuration
type HealthConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Path string `json:"path" yaml:"path"`
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
Timeout time.Duration `json:"timeout" yaml:"timeout"`
// Checks to perform
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
CheckCache bool `json:"checkCache" yaml:"checkCache"`
// Thresholds
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
}
// TransportConfig contains HTTP transport configuration
type TransportConfig struct {
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
// TLS configuration
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
// Proxy settings
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
NoProxy []string `json:"noProxy" yaml:"noProxy"`
}
// PoolConfig contains connection pool configuration
type PoolConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Size int `json:"size" yaml:"size"`
MinSize int `json:"minSize" yaml:"minSize"`
MaxSize int `json:"maxSize" yaml:"maxSize"`
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
// Health checking
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
}
// CircuitConfig contains circuit breaker configuration
type CircuitConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
Interval time.Duration `json:"interval" yaml:"interval"`
Timeout time.Duration `json:"timeout" yaml:"timeout"`
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
// Circuit states
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
// Monitoring
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
}
-263
View File
@@ -1,263 +0,0 @@
//go:build !yaegi
package config
import (
"encoding/json"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
},
Session: SessionConfig{
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
},
Redis: RedisConfig{
Password: "redis-password",
SentinelPassword: "sentinel-password",
},
}
// Marshal to JSON
jsonBytes, err := json.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to JSON: %v", err)
}
jsonStr := string(jsonBytes)
// Verify secrets are redacted
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
t.Error("ClientSecret should be redacted in JSON output")
}
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
t.Error("Session.Secret should be redacted in JSON output")
}
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
t.Error("Session.EncryptionKey should be redacted in JSON output")
}
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
t.Error("Session.SigningKey should be redacted in JSON output")
}
if !contains(jsonStr, `"password":"[REDACTED]"`) {
t.Error("Redis.Password should be redacted in JSON output")
}
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
t.Error("Redis.SentinelPassword should be redacted in JSON output")
}
// Verify non-secret fields are preserved
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
t.Error("IssuerURL should be preserved in JSON output")
}
if !contains(jsonStr, `"clientID":"test-client"`) {
t.Error("ClientID should be preserved in JSON output")
}
}
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
},
Session: SessionConfig{
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
},
Redis: RedisConfig{
Password: "redis-password",
SentinelPassword: "sentinel-password",
},
}
// Marshal to YAML
yamlBytes, err := yaml.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to YAML: %v", err)
}
yamlStr := string(yamlBytes)
// Verify secrets are redacted
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
t.Error("ClientSecret should be redacted in YAML output")
}
if !contains(yamlStr, "secret: '[REDACTED]'") {
t.Error("Session.Secret should be redacted in YAML output")
}
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
t.Error("Session.EncryptionKey should be redacted in YAML output")
}
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
t.Error("Session.SigningKey should be redacted in YAML output")
}
if !contains(yamlStr, "password: '[REDACTED]'") {
t.Error("Redis.Password should be redacted in YAML output")
}
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
t.Error("Redis.SentinelPassword should be redacted in YAML output")
}
// Verify non-secret fields are preserved
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
t.Error("IssuerURL should be preserved in YAML output")
}
if !contains(yamlStr, "clientID: test-client") {
t.Error("ClientID should be preserved in YAML output")
}
}
// TestProviderConfigMarshalling tests individual struct marshalling
func TestProviderConfigMarshalling(t *testing.T) {
provider := ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(provider)
if err != nil {
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
t.Error("ClientSecret should be redacted in JSON output")
}
if !contains(jsonStr, `"clientID":"test-client"`) {
t.Error("ClientID should be preserved in JSON output")
}
// Test YAML marshalling
yamlBytes, err := yaml.Marshal(provider)
if err != nil {
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
}
yamlStr := string(yamlBytes)
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
t.Error("ClientSecret should be redacted in YAML output")
}
if !contains(yamlStr, "clientID: test-client") {
t.Error("ClientID should be preserved in YAML output")
}
}
// TestSessionConfigMarshalling tests session config marshalling
func TestSessionConfigMarshalling(t *testing.T) {
session := SessionConfig{
Name: "session-cookie",
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
Domain: "example.com",
Secure: true,
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(session)
if err != nil {
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
t.Error("Secret should be redacted in JSON output")
}
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
t.Error("EncryptionKey should be redacted in JSON output")
}
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
t.Error("SigningKey should be redacted in JSON output")
}
if !contains(jsonStr, `"name":"session-cookie"`) {
t.Error("Name should be preserved in JSON output")
}
if !contains(jsonStr, `"domain":"example.com"`) {
t.Error("Domain should be preserved in JSON output")
}
}
// TestRedisConfigMarshalling tests Redis config marshalling
func TestRedisConfigMarshalling(t *testing.T) {
redis := RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
Password: "redis-password",
SentinelPassword: "sentinel-password",
Addr: "localhost:6379",
DB: 1,
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(redis)
if err != nil {
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"password":"[REDACTED]"`) {
t.Error("Password should be redacted in JSON output")
}
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
t.Error("SentinelPassword should be redacted in JSON output")
}
if !contains(jsonStr, `"addr":"localhost:6379"`) {
t.Error("Addr should be preserved in JSON output")
}
if !contains(jsonStr, `"db":1`) {
t.Error("DB should be preserved in JSON output")
}
}
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
func TestEmptySecretsNotRedacted(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "", // Empty secret
},
Session: SessionConfig{
Secret: "", // Empty secret
EncryptionKey: "", // Empty secret
},
Redis: RedisConfig{
Password: "", // Empty secret
},
}
// Marshal to JSON
jsonBytes, err := json.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to JSON: %v", err)
}
jsonStr := string(jsonBytes)
// Verify empty secrets are not shown as redacted
if contains(jsonStr, "[REDACTED]") {
t.Error("Empty secrets should not be shown as [REDACTED]")
}
}
// Helper function to check if string contains substring
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}
-652
View File
@@ -1,652 +0,0 @@
// Package config provides validation for unified configuration
package config
import (
"fmt"
"net/url"
"regexp"
"strings"
"time"
)
// ValidationError represents a configuration validation error
type ValidationError struct {
Field string
Message string
Value interface{}
}
// Error implements the error interface
func (e *ValidationError) Error() string {
if e.Value != nil {
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
}
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
}
// ValidationErrors represents multiple validation errors
type ValidationErrors []ValidationError
// Error implements the error interface
func (e ValidationErrors) Error() string {
if len(e) == 0 {
return ""
}
var messages []string
for _, err := range e {
messages = append(messages, err.Error())
}
return strings.Join(messages, "; ")
}
// Validate performs comprehensive validation on the unified configuration
func (c *UnifiedConfig) Validate() error {
var errors ValidationErrors
// Validate Provider configuration
if err := c.validateProvider(); err != nil {
errors = append(errors, err...)
}
// Validate Session configuration
if err := c.validateSession(); err != nil {
errors = append(errors, err...)
}
// Validate Token configuration
if err := c.validateToken(); err != nil {
errors = append(errors, err...)
}
// Validate Redis configuration (uses existing validation)
if err := c.Redis.Validate(); err != nil {
errors = append(errors, ValidationError{
Field: "Redis",
Message: err.Error(),
})
}
// Validate Security configuration
if err := c.validateSecurity(); err != nil {
errors = append(errors, err...)
}
// Validate Middleware configuration
if err := c.validateMiddleware(); err != nil {
errors = append(errors, err...)
}
// Validate Cache configuration
if err := c.validateCache(); err != nil {
errors = append(errors, err...)
}
// Validate RateLimit configuration
if err := c.validateRateLimit(); err != nil {
errors = append(errors, err...)
}
// Validate Logging configuration
if err := c.validateLogging(); err != nil {
errors = append(errors, err...)
}
// Validate Metrics configuration
if err := c.validateMetrics(); err != nil {
errors = append(errors, err...)
}
// Validate Transport configuration
if err := c.validateTransport(); err != nil {
errors = append(errors, err...)
}
// Validate Circuit configuration
if err := c.validateCircuit(); err != nil {
errors = append(errors, err...)
}
if len(errors) > 0 {
return errors
}
return nil
}
// validateProvider validates provider configuration
func (c *UnifiedConfig) validateProvider() ValidationErrors {
var errors ValidationErrors
// IssuerURL is required and must be a valid URL
if c.Provider.IssuerURL == "" {
errors = append(errors, ValidationError{
Field: "Provider.IssuerURL",
Message: "issuer URL is required",
})
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
errors = append(errors, ValidationError{
Field: "Provider.IssuerURL",
Message: "invalid issuer URL",
Value: c.Provider.IssuerURL,
})
}
// ClientID is required
if c.Provider.ClientID == "" {
errors = append(errors, ValidationError{
Field: "Provider.ClientID",
Message: "client ID is required",
})
}
// ClientSecret is required (except for public clients with PKCE)
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
errors = append(errors, ValidationError{
Field: "Provider.ClientSecret",
Message: "client secret is required (or enable PKCE for public clients)",
})
}
// RedirectURL must be valid if provided
if c.Provider.RedirectURL != "" {
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
errors = append(errors, ValidationError{
Field: "Provider.RedirectURL",
Message: "invalid redirect URL",
Value: c.Provider.RedirectURL,
})
}
}
// Scopes must include 'openid' for OIDC
hasOpenID := false
for _, scope := range c.Provider.Scopes {
if scope == "openid" {
hasOpenID = true
break
}
}
if !hasOpenID && !c.Provider.OverrideScopes {
errors = append(errors, ValidationError{
Field: "Provider.Scopes",
Message: "scopes must include 'openid' for OIDC",
Value: c.Provider.Scopes,
})
}
// JWK cache period must be positive
if c.Provider.JWKCachePeriod < 0 {
errors = append(errors, ValidationError{
Field: "Provider.JWKCachePeriod",
Message: "JWK cache period must be positive",
Value: c.Provider.JWKCachePeriod,
})
}
return errors
}
// validateSession validates session configuration
func (c *UnifiedConfig) validateSession() ValidationErrors {
var errors ValidationErrors
// Session name must not be empty
if c.Session.Name == "" {
errors = append(errors, ValidationError{
Field: "Session.Name",
Message: "session name is required",
})
}
// Session secret or encryption key is required
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
errors = append(errors, ValidationError{
Field: "Session",
Message: "either session secret or encryption key is required",
})
}
// Encryption key must be at least 32 bytes for security
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
errors = append(errors, ValidationError{
Field: "Session.EncryptionKey",
Message: "encryption key must be at least 32 characters for proper security",
Value: len(c.Session.EncryptionKey),
})
}
// ChunkSize must be reasonable (between 1KB and 10KB)
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
errors = append(errors, ValidationError{
Field: "Session.ChunkSize",
Message: "chunk size must be between 1000 and 10000 bytes",
Value: c.Session.ChunkSize,
})
}
// MaxChunks must be reasonable (between 1 and 100)
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
errors = append(errors, ValidationError{
Field: "Session.MaxChunks",
Message: "max chunks must be between 1 and 100",
Value: c.Session.MaxChunks,
})
}
// SameSite must be valid
validSameSite := map[string]bool{
"": true,
"Lax": true,
"Strict": true,
"None": true,
}
if !validSameSite[c.Session.SameSite] {
errors = append(errors, ValidationError{
Field: "Session.SameSite",
Message: "invalid SameSite value (must be Lax, Strict, or None)",
Value: c.Session.SameSite,
})
}
// StorageType must be valid
validStorage := map[string]bool{
"memory": true,
"redis": true,
"cookie": true,
}
if !validStorage[c.Session.StorageType] {
errors = append(errors, ValidationError{
Field: "Session.StorageType",
Message: "invalid storage type (must be memory, redis, or cookie)",
Value: c.Session.StorageType,
})
}
return errors
}
// validateToken validates token configuration
func (c *UnifiedConfig) validateToken() ValidationErrors {
var errors ValidationErrors
// Token TTLs must be positive
if c.Token.AccessTokenTTL <= 0 {
errors = append(errors, ValidationError{
Field: "Token.AccessTokenTTL",
Message: "access token TTL must be positive",
Value: c.Token.AccessTokenTTL,
})
}
if c.Token.RefreshTokenTTL <= 0 {
errors = append(errors, ValidationError{
Field: "Token.RefreshTokenTTL",
Message: "refresh token TTL must be positive",
Value: c.Token.RefreshTokenTTL,
})
}
// Validation mode must be valid
validModes := map[string]bool{
"jwt": true,
"introspect": true,
"hybrid": true,
}
if !validModes[c.Token.ValidationMode] {
errors = append(errors, ValidationError{
Field: "Token.ValidationMode",
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
Value: c.Token.ValidationMode,
})
}
// Introspect URL required for introspect or hybrid mode
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
errors = append(errors, ValidationError{
Field: "Token.IntrospectURL",
Message: "introspect URL is required for introspect or hybrid validation mode",
})
}
// Clock skew must be reasonable (0 to 10 minutes)
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
errors = append(errors, ValidationError{
Field: "Token.ClockSkew",
Message: "clock skew must be between 0 and 10 minutes",
Value: c.Token.ClockSkew,
})
}
return errors
}
// validateSecurity validates security configuration
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
var errors ValidationErrors
// Validate allowed user domains are valid domains
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
for _, domain := range c.Security.AllowedUserDomains {
if !domainRegex.MatchString(domain) {
errors = append(errors, ValidationError{
Field: "Security.AllowedUserDomains",
Message: "invalid domain format",
Value: domain,
})
}
}
// Max login attempts must be reasonable
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
errors = append(errors, ValidationError{
Field: "Security.MaxLoginAttempts",
Message: "max login attempts must be between 0 and 100",
Value: c.Security.MaxLoginAttempts,
})
}
// Lockout duration must be reasonable
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
errors = append(errors, ValidationError{
Field: "Security.LockoutDuration",
Message: "lockout duration must be between 0 and 24 hours",
Value: c.Security.LockoutDuration,
})
}
return errors
}
// validateMiddleware validates middleware configuration
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
var errors ValidationErrors
// Max request size must be reasonable (1KB to 100MB)
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
errors = append(errors, ValidationError{
Field: "Middleware.MaxRequestSize",
Message: "max request size must be between 1KB and 100MB",
Value: c.Middleware.MaxRequestSize,
})
}
// Request timeout must be reasonable
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
errors = append(errors, ValidationError{
Field: "Middleware.RequestTimeout",
Message: "request timeout must be between 1 second and 5 minutes",
Value: c.Middleware.RequestTimeout,
})
}
return errors
}
// validateCache validates cache configuration
func (c *UnifiedConfig) validateCache() ValidationErrors {
var errors ValidationErrors
if !c.Cache.Enabled {
return errors
}
// Cache type must be valid
validTypes := map[string]bool{
"memory": true,
"redis": true,
"hybrid": true,
}
if !validTypes[c.Cache.Type] {
errors = append(errors, ValidationError{
Field: "Cache.Type",
Message: "invalid cache type (must be memory, redis, or hybrid)",
Value: c.Cache.Type,
})
}
// Max entries must be reasonable
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
errors = append(errors, ValidationError{
Field: "Cache.MaxEntries",
Message: "max entries must be between 10 and 1000000",
Value: c.Cache.MaxEntries,
})
}
// Eviction policy must be valid
validEviction := map[string]bool{
"lru": true,
"lfu": true,
"fifo": true,
}
if !validEviction[c.Cache.EvictionPolicy] {
errors = append(errors, ValidationError{
Field: "Cache.EvictionPolicy",
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
Value: c.Cache.EvictionPolicy,
})
}
return errors
}
// validateRateLimit validates rate limiting configuration
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
var errors ValidationErrors
if !c.RateLimit.Enabled {
return errors
}
// Requests per second must be reasonable
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
errors = append(errors, ValidationError{
Field: "RateLimit.RequestsPerSecond",
Message: "requests per second must be between 1 and 10000",
Value: c.RateLimit.RequestsPerSecond,
})
}
// Burst must be at least as large as requests per second
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
errors = append(errors, ValidationError{
Field: "RateLimit.Burst",
Message: "burst must be at least as large as requests per second",
Value: c.RateLimit.Burst,
})
}
// Key type must be valid
validKeyTypes := map[string]bool{
"ip": true,
"user": true,
"token": true,
"custom": true,
}
if !validKeyTypes[c.RateLimit.KeyType] {
errors = append(errors, ValidationError{
Field: "RateLimit.KeyType",
Message: "invalid key type (must be ip, user, token, or custom)",
Value: c.RateLimit.KeyType,
})
}
return errors
}
// validateLogging validates logging configuration
func (c *UnifiedConfig) validateLogging() ValidationErrors {
var errors ValidationErrors
// Log level must be valid
validLevels := map[string]bool{
"debug": true,
"info": true,
"warn": true,
"error": true,
}
if !validLevels[c.Logging.Level] {
errors = append(errors, ValidationError{
Field: "Logging.Level",
Message: "invalid log level (must be debug, info, warn, or error)",
Value: c.Logging.Level,
})
}
// Format must be valid
validFormats := map[string]bool{
"json": true,
"text": true,
"structured": true,
}
if !validFormats[c.Logging.Format] {
errors = append(errors, ValidationError{
Field: "Logging.Format",
Message: "invalid log format (must be json, text, or structured)",
Value: c.Logging.Format,
})
}
// Output must be valid
validOutputs := map[string]bool{
"stdout": true,
"stderr": true,
"file": true,
}
if !validOutputs[c.Logging.Output] {
errors = append(errors, ValidationError{
Field: "Logging.Output",
Message: "invalid log output (must be stdout, stderr, or file)",
Value: c.Logging.Output,
})
}
// File path required if output is file
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
errors = append(errors, ValidationError{
Field: "Logging.FilePath",
Message: "file path is required when output is 'file'",
})
}
return errors
}
// validateMetrics validates metrics configuration
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
var errors ValidationErrors
if !c.Metrics.Enabled {
return errors
}
// Provider must be valid
validProviders := map[string]bool{
"prometheus": true,
"statsd": true,
"otlp": true,
}
if !validProviders[c.Metrics.Provider] {
errors = append(errors, ValidationError{
Field: "Metrics.Provider",
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
Value: c.Metrics.Provider,
})
}
// Endpoint required for some providers
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
errors = append(errors, ValidationError{
Field: "Metrics.Endpoint",
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
})
}
return errors
}
// validateTransport validates transport configuration
func (c *UnifiedConfig) validateTransport() ValidationErrors {
var errors ValidationErrors
// Max connections must be reasonable
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
errors = append(errors, ValidationError{
Field: "Transport.MaxIdleConns",
Message: "max idle connections must be between 0 and 10000",
Value: c.Transport.MaxIdleConns,
})
}
// TLS min version must be valid
validTLSVersions := map[string]bool{
"TLS1.0": true,
"TLS1.1": true,
"TLS1.2": true,
"TLS1.3": true,
}
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
errors = append(errors, ValidationError{
Field: "Transport.TLSMinVersion",
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
Value: c.Transport.TLSMinVersion,
})
}
// Proxy URL must be valid if provided
if c.Transport.ProxyURL != "" {
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
errors = append(errors, ValidationError{
Field: "Transport.ProxyURL",
Message: "invalid proxy URL",
Value: c.Transport.ProxyURL,
})
}
}
return errors
}
// validateCircuit validates circuit breaker configuration
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
var errors ValidationErrors
if !c.Circuit.Enabled {
return errors
}
// Consecutive failures must be reasonable
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
errors = append(errors, ValidationError{
Field: "Circuit.ConsecutiveFailures",
Message: "consecutive failures must be between 1 and 100",
Value: c.Circuit.ConsecutiveFailures,
})
}
// Failure ratio must be between 0 and 1
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
errors = append(errors, ValidationError{
Field: "Circuit.FailureRatio",
Message: "failure ratio must be between 0 and 1",
Value: c.Circuit.FailureRatio,
})
}
// OnOpen action must be valid
validActions := map[string]bool{
"reject": true,
"fallback": true,
"passthrough": true,
}
if !validActions[c.Circuit.OnOpen] {
errors = append(errors, ValidationError{
Field: "Circuit.OnOpen",
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
Value: c.Circuit.OnOpen,
})
}
return errors
}
-588
View File
@@ -1,588 +0,0 @@
//go:build !yaegi
package config
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
func TestValidateUnifiedConfig(t *testing.T) {
tests := []struct {
name string
config *UnifiedConfig
expectError bool
errorField string
}{
{
name: "valid config with minimum requirements",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
Scopes: []string{"openid", "profile", "email"},
},
Session: SessionConfig{
Name: "oidc_session",
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 4000,
MaxChunks: 5,
StorageType: "cookie",
},
Token: TokenConfig{
AccessTokenTTL: time.Hour,
RefreshTokenTTL: 24 * time.Hour,
ValidationMode: "jwt",
},
Middleware: MiddlewareConfig{
MaxRequestSize: 10 * 1024 * 1024,
RequestTimeout: 30 * time.Second,
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
},
},
expectError: false,
},
{
name: "missing provider URL",
config: &UnifiedConfig{
Provider: ProviderConfig{
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
},
expectError: true,
errorField: "Provider.IssuerURL",
},
{
name: "missing client ID",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
},
expectError: true,
errorField: "Provider.ClientID",
},
{
name: "encryption key too short",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "too-short",
},
},
expectError: true,
errorField: "Session.EncryptionKey",
},
{
name: "invalid chunk size",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 500, // Too small
},
},
expectError: true,
errorField: "Session.ChunkSize",
},
{
name: "invalid max chunks",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 4000,
MaxChunks: 0, // Too small
},
},
expectError: true,
errorField: "Session.MaxChunks",
},
{
name: "invalid TLS min version",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
Transport: TransportConfig{
TLSMinVersion: "1.0", // Too old
},
},
expectError: true,
errorField: "Transport.TLSMinVersion",
},
{
name: "invalid circuit breaker failure ratio",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
Circuit: CircuitConfig{
Enabled: true,
FailureRatio: 1.5, // Too high
},
},
expectError: true,
errorField: "Circuit.FailureRatio",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.expectError {
if err == nil {
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
} else if validationErrs, ok := err.(ValidationErrors); ok {
found := false
for _, e := range validationErrs {
if e.Field == tt.errorField {
found = true
break
}
}
if !found {
t.Errorf("Expected validation error for field %s, but got errors for: %v",
tt.errorField, validationErrs)
}
}
} else {
if err != nil {
t.Errorf("Expected no validation error, but got: %v", err)
}
}
})
}
}
// TestValidationErrorMessage tests validation error formatting
func TestValidationErrorMessage(t *testing.T) {
errs := ValidationErrors{
{
Field: "Provider.IssuerURL",
Message: "is required",
Value: nil,
},
{
Field: "Session.EncryptionKey",
Message: "must be at least 32 characters",
Value: 16,
},
}
errMsg := errs.Error()
if !strings.Contains(errMsg, "Provider.IssuerURL") {
t.Error("Error message should contain field name Provider.IssuerURL")
}
if !strings.Contains(errMsg, "is required") {
t.Error("Error message should contain 'is required'")
}
if !strings.Contains(errMsg, "Session.EncryptionKey") {
t.Error("Error message should contain field name Session.EncryptionKey")
}
if !strings.Contains(errMsg, "must be at least 32 characters") {
t.Error("Error message should contain 'must be at least 32 characters'")
}
}
// TestValidateRedisConfig tests Redis configuration validation
func TestValidateRedisConfig(t *testing.T) {
tests := []struct {
name string
config *RedisConfig
expectError bool
errorMsg string
}{
{
name: "valid standalone config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeStandalone,
Addr: "localhost:6379",
},
expectError: false,
},
{
name: "missing address for standalone",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeStandalone,
Addr: "",
},
expectError: true,
errorMsg: "Redis address is required",
},
{
name: "valid cluster config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
},
expectError: false,
},
{
name: "missing cluster addresses",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
ClusterAddrs: []string{},
},
expectError: true,
errorMsg: "cluster address is required",
},
{
name: "valid sentinel config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "mymaster",
SentinelAddrs: []string{"localhost:26379"},
},
expectError: false,
},
{
name: "missing master name for sentinel",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "",
SentinelAddrs: []string{"localhost:26379"},
},
expectError: true,
errorMsg: "Master name is required",
},
{
name: "missing sentinel addresses",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "mymaster",
SentinelAddrs: []string{},
},
expectError: true,
errorMsg: "sentinel address is required",
},
{
name: "disabled redis needs no validation",
config: &RedisConfig{
Enabled: false,
},
expectError: false,
},
{
name: "invalid redis mode",
config: &RedisConfig{
Enabled: true,
Mode: "invalid-mode",
},
expectError: true,
errorMsg: "Invalid Redis mode",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.expectError {
if err == nil {
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
} else if !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
}
} else {
if err != nil {
t.Errorf("Expected no validation error, but got: %v", err)
}
}
})
}
}
// ============================================================================
// validateRateLimit Tests
// ============================================================================
func TestValidateRateLimit_Disabled(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = false
errors := config.validateRateLimit()
assert.Empty(t, errors, "Should have no errors when rate limiting is disabled")
}
func TestValidateRateLimit_ValidConfig(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 100
config.RateLimit.Burst = 200
config.RateLimit.KeyType = "ip"
errors := config.validateRateLimit()
assert.Empty(t, errors, "Should have no errors for valid rate limit config")
}
func TestValidateRateLimit_RequestsPerSecondTooLow(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 0
config.RateLimit.Burst = 100
config.RateLimit.KeyType = "ip"
errors := config.validateRateLimit()
require.Len(t, errors, 1)
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
assert.Contains(t, errors[0].Message, "between 1 and 10000")
}
func TestValidateRateLimit_RequestsPerSecondTooHigh(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 15000
config.RateLimit.Burst = 20000
config.RateLimit.KeyType = "ip"
errors := config.validateRateLimit()
require.Len(t, errors, 1)
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
assert.Contains(t, errors[0].Message, "between 1 and 10000")
}
func TestValidateRateLimit_BurstTooSmall(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 100
config.RateLimit.Burst = 50 // Less than RequestsPerSecond
config.RateLimit.KeyType = "ip"
errors := config.validateRateLimit()
require.Len(t, errors, 1)
assert.Equal(t, "RateLimit.Burst", errors[0].Field)
assert.Contains(t, errors[0].Message, "at least as large as requests per second")
}
func TestValidateRateLimit_InvalidKeyType(t *testing.T) {
tests := []struct {
name string
keyType string
}{
{"empty key type", ""},
{"invalid key type", "invalid"},
{"random string", "foobar"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 100
config.RateLimit.Burst = 200
config.RateLimit.KeyType = tt.keyType
errors := config.validateRateLimit()
require.Len(t, errors, 1)
assert.Equal(t, "RateLimit.KeyType", errors[0].Field)
assert.Contains(t, errors[0].Message, "invalid key type")
})
}
}
func TestValidateRateLimit_ValidKeyTypes(t *testing.T) {
validKeyTypes := []string{"ip", "user", "token", "custom"}
for _, keyType := range validKeyTypes {
t.Run(keyType, func(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 100
config.RateLimit.Burst = 200
config.RateLimit.KeyType = keyType
errors := config.validateRateLimit()
assert.Empty(t, errors, "Should have no errors for valid key type: %s", keyType)
})
}
}
func TestValidateRateLimit_MultipleErrors(t *testing.T) {
config := NewUnifiedConfig()
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = 0 // Too low
config.RateLimit.Burst = 50 // Will pass (0 < 50)
config.RateLimit.KeyType = "invalid" // Invalid
errors := config.validateRateLimit()
// Should have 2 errors (rps and keyType)
assert.Len(t, errors, 2)
// Check each error is present
fields := make(map[string]bool)
for _, err := range errors {
fields[err.Field] = true
}
assert.True(t, fields["RateLimit.RequestsPerSecond"])
assert.True(t, fields["RateLimit.KeyType"])
}
// ============================================================================
// validateMetrics Tests
// ============================================================================
func TestValidateMetrics_Disabled(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = false
errors := config.validateMetrics()
assert.Empty(t, errors, "Should have no errors when metrics are disabled")
}
func TestValidateMetrics_ValidPrometheus(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "prometheus"
config.Metrics.Endpoint = "" // Prometheus doesn't require endpoint
errors := config.validateMetrics()
assert.Empty(t, errors, "Should have no errors for valid prometheus config")
}
func TestValidateMetrics_ValidStatsd(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "statsd"
config.Metrics.Endpoint = "localhost:8125"
errors := config.validateMetrics()
assert.Empty(t, errors, "Should have no errors for valid statsd config")
}
func TestValidateMetrics_ValidOTLP(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "otlp"
config.Metrics.Endpoint = "localhost:4317"
errors := config.validateMetrics()
assert.Empty(t, errors, "Should have no errors for valid otlp config")
}
func TestValidateMetrics_InvalidProvider(t *testing.T) {
tests := []struct {
name string
provider string
}{
{"empty provider", ""},
{"invalid provider", "invalid"},
{"datadog", "datadog"},
{"influx", "influx"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = tt.provider
config.Metrics.Endpoint = "localhost:8080"
errors := config.validateMetrics()
require.Len(t, errors, 1)
assert.Equal(t, "Metrics.Provider", errors[0].Field)
assert.Contains(t, errors[0].Message, "invalid metrics provider")
})
}
}
func TestValidateMetrics_StatsdMissingEndpoint(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "statsd"
config.Metrics.Endpoint = "" // Missing required endpoint
errors := config.validateMetrics()
require.Len(t, errors, 1)
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
assert.Contains(t, errors[0].Message, "endpoint is required for statsd provider")
}
func TestValidateMetrics_OTLPMissingEndpoint(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "otlp"
config.Metrics.Endpoint = "" // Missing required endpoint
errors := config.validateMetrics()
require.Len(t, errors, 1)
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
assert.Contains(t, errors[0].Message, "endpoint is required for otlp provider")
}
func TestValidateMetrics_MultipleErrors(t *testing.T) {
config := NewUnifiedConfig()
config.Metrics.Enabled = true
config.Metrics.Provider = "invalid" // Invalid provider
config.Metrics.Endpoint = "" // Would be missing if provider was statsd/otlp
errors := config.validateMetrics()
// Should have at least 1 error for invalid provider
assert.NotEmpty(t, errors)
assert.Equal(t, "Metrics.Provider", errors[0].Field)
}
File diff suppressed because it is too large Load Diff
+456
View File
@@ -0,0 +1,456 @@
# Configuration Reference
Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
- [Access Control](#access-control)
- [Headers Configuration](#headers-configuration)
- [Security Headers](#security-headers)
- [Scope Configuration](#scope-configuration)
- [Advanced Options](#advanced-options)
---
## Required Parameters
| Parameter | Type | Description | Example |
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
### Basic Configuration Example
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
```
---
## Optional Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
### TLS Termination at Load Balancer
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
```yaml
forceHTTPS: true # Required for correct redirect URIs
```
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
---
## Security Options
### Audience Validation
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audience` | string | `clientID` | Expected audience for access token validation |
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
#### Production Security Configuration
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
```
#### Opaque Token Support
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Other Security Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
---
## Session Management
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
### Multi-Subdomain Setup
```yaml
cookieDomain: .example.com # Share cookies across subdomains
```
### Multiple Middleware Instances
When running multiple middleware instances with different authorization requirements, use unique prefixes:
```yaml
# User authentication middleware
---
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-userauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_userauth_"
sessionEncryptionKey: user-encryption-key-min-32-bytes
# ... other config
---
# Admin authentication middleware
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-adminauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_adminauth_"
sessionEncryptionKey: admin-encryption-key-min-32-bytes
allowedUsers:
- admin@example.com
# ... other config
```
### Extended Session Duration
```yaml
sessionMaxAge: 604800 # 7 days
# Common values:
# 3600 - 1 hour (high security)
# 86400 - 1 day (default)
# 259200 - 3 days
# 604800 - 7 days
# 2592000 - 30 days
```
---
## Access Control
### User Restrictions
| Parameter | Type | Description |
|-----------|------|-------------|
| `allowedUserDomains` | []string | Restrict to specific email domains |
| `allowedUsers` | []string | Specific email addresses allowed |
| `allowedRolesAndGroups` | []string | Required roles or groups |
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
### Domain Restriction
```yaml
allowedUserDomains:
- company.com
- subsidiary.com
```
### Specific User Access
```yaml
allowedUsers:
- user@example.com
- contractor@external.org
```
### Role-Based Access Control
```yaml
allowedRolesAndGroups:
- admin
- developer
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
```
### Access Control Logic
- If only `allowedUsers` is set: Only specified emails can access
- If only `allowedUserDomains` is set: Only specified domains can access
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
- If neither is set: Any authenticated user can access
### Users Without Email (Azure AD)
For Azure AD service accounts or users without email:
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
```
---
## Headers Configuration
### Default Headers
The middleware sets these headers for downstream services:
| Header | Description |
|--------|-------------|
| `X-Forwarded-User` | User's email address |
| `X-User-Groups` | Comma-separated user groups |
| `X-User-Roles` | Comma-separated user roles |
| `X-Auth-Request-Redirect` | Original request URI |
| `X-Auth-Request-User` | User's email address |
| `X-Auth-Request-Token` | User's ID token |
### Minimal Headers Mode
For "431 Request Header Fields Too Large" errors:
```yaml
minimalHeaders: true # Only forwards X-Forwarded-User
```
### Custom Templated Headers
```yaml
headers:
- name: "X-User-Email"
value: "{{{{.Claims.email}}}}"
- name: "X-User-ID"
value: "{{{{.Claims.sub}}}}"
- name: "Authorization"
value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles"
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
```
**Template Variables:**
- `{{.Claims.field}}` - ID token claims
- `{{.AccessToken}}` - Raw access token
- `{{.IdToken}}` - Raw ID token
- `{{.RefreshToken}}` - Raw refresh token
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
---
## Security Headers
### Security Profiles
| Profile | Use Case | Security Level |
|---------|----------|----------------|
| `default` | Standard web apps | High |
| `strict` | Maximum security | Very High |
| `development` | Local development | Medium |
| `api` | API endpoints | High |
| `custom` | Custom requirements | Configurable |
### Basic Configuration
```yaml
securityHeaders:
enabled: true
profile: "default"
```
### API with CORS
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.example.com"
corsAllowCredentials: true
```
### Custom Security Configuration
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
# HSTS
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and Content Protection
frameOptions: "DENY"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# CORS
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type"]
corsAllowCredentials: true
corsMaxAge: 86400
# Custom Headers
customHeaders:
X-Custom-Header: "value"
# Server Identification
disableServerHeader: true
disablePoweredByHeader: true
```
### CORS Origin Patterns
```yaml
corsAllowedOrigins:
- "https://example.com" # Exact match
- "https://*.example.com" # Subdomain wildcard
- "http://localhost:*" # Port wildcard (development)
```
---
## Scope Configuration
### Default Behavior (Append Mode)
```yaml
scopes:
- roles
- custom_scope
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
```
### Override Mode
```yaml
overrideScopes: true
scopes:
- openid
- profile
- custom_scope
# Result: ["openid", "profile", "custom_scope"]
```
---
## Advanced Options
### Dynamic Client Registration (RFC 7591)
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
```
### Multi-Replica Deployment
Without Redis, disable replay detection:
```yaml
disableReplayDetection: true
```
With Redis (recommended):
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
See [REDIS.md](REDIS.md) for complete Redis configuration.
---
## Kubernetes Secrets
Reference secrets instead of hardcoding sensitive values:
```yaml
providerURL: urn:k8s:secret:oidc-secret:ISSUER
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
clientSecret: urn:k8s:secret:oidc-secret:SECRET
```
Create the secret:
```bash
kubectl create secret generic oidc-secret \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=your-client-id \
--from-literal=SECRET=your-client-secret \
-n traefik
```
---
## Environment Variable Naming
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
```yaml
# Bad - may cause issues
sessionEncryptionKey: ${OIDC_SECRET_API}
# Good
sessionEncryptionKey: ${OIDC_SECRET_SVC}
```
+455
View File
@@ -0,0 +1,455 @@
# Development Guide
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Local Development Setup](#local-development-setup)
- [Running Tests](#running-tests)
- [Test Categories](#test-categories)
- [CI/CD Pipeline](#cicd-pipeline)
- [Code Quality](#code-quality)
- [Contributing](#contributing)
---
## Prerequisites
- **Go 1.23+** for plugin compilation
- **Docker & Docker Compose** for local testing
- **OIDC Provider** credentials (Google, Azure, etc.)
### Required Development Tools
```bash
# golangci-lint (comprehensive linting)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck (static analysis)
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec (security scanning)
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck (vulnerability scanning)
go install golang.org/x/vuln/cmd/govulncheck@latest
```
---
## Local Development Setup
### Docker Compose Environment
The repository includes a Docker Compose setup for testing the plugin locally.
#### 1. Host Configuration
Add to `/etc/hosts`:
```bash
127.0.0.1 hello.localhost
127.0.0.1 traefik.localhost
```
#### 2. Plugin Configuration
The plugin is loaded using Traefik's **local plugins mode**:
- Plugin source: Parent directory (`../`)
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
- Configuration: `experimental.localPlugins` in `traefik.yml`
#### 3. OIDC Provider Setup
Edit `docker/dynamic.yml` with your provider details:
**Google:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
scopes:
- "openid"
- "email"
- "profile"
```
**Azure AD:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
scopes:
- "openid"
- "email"
- "profile"
```
#### 4. Start Environment
```bash
cd docker
docker-compose up -d
```
#### 5. Test Plugin
- **Protected App**: http://hello.localhost (redirects to OIDC)
- **Traefik Dashboard**: http://traefik.localhost:8080
### Development Workflow
1. **Edit plugin code** in the project root
2. **Build and test** (optional syntax check):
```bash
go mod tidy
go build .
go test ./...
```
3. **Restart Traefik** to reload plugin:
```bash
docker-compose restart traefik
```
4. **Test changes** at http://hello.localhost
### Debugging
**View plugin logs:**
```bash
docker-compose logs -f traefik | grep traefikoidc
```
**Check plugin loading:**
```bash
docker-compose logs traefik | grep -i plugin
```
**Verify plugin directory:**
```bash
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
```
---
## Running Tests
### Quick Start
```bash
# Fast development testing (< 30 seconds)
go test ./... -short
# Standard tests with race detector
go test -race -timeout=15m ./...
# With coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
```
### Test Modes
| Mode | Command | Duration | Use Case |
|------|---------|----------|----------|
| Quick | `go test ./... -short` | < 30s | During development |
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
### Environment Variables
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1
export RUN_LONG_TESTS=1
export RUN_STRESS_TESTS=1
# Disable specific features
export DISABLE_LEAK_DETECTION=1
# Customize test parameters
export TEST_MAX_CONCURRENCY=10
export TEST_MAX_ITERATIONS=50
export TEST_MEMORY_THRESHOLD_MB=25.5
```
---
## Test Categories
### Quick Tests (Default)
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Essential memory leak checks
**Configuration:**
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Timeout: 10 seconds
### Extended Tests
- Comprehensive testing before commits
- More iterations (5-10)
- Enhanced memory leak detection
**Configuration:**
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Timeout: 30 seconds
### Long Tests
- Performance validation
- High iteration counts (50-100)
- Large data sets
**Configuration:**
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Timeout: 60 seconds
### Stress Tests
- Maximum load testing
- Edge case validation
- Extreme parameters
**Configuration:**
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Timeout: 120 seconds
### Running Specific Test Suites
```bash
# Memory leak tests
go test -v -run='.*Leak.*' ./...
# Integration tests
go test -v -run='.*Integration.*' ./...
# Regression tests
go test -v -run='.*Regression.*' ./...
# Provider-specific tests
go test -v -run='.*Azure.*' ./...
go test -v -run='.*Google.*' ./...
```
### Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
---
## CI/CD Pipeline
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
### Triggered On
- Pull requests to `main` branch
- Pushes to `main` branch
### Parallel Jobs
#### Code Quality (3 checks)
- **Format & Basic Checks** - gofmt, go vet, go mod
- **golangci-lint** - 30+ linters
- **Staticcheck** - Advanced static analysis
#### Security (3 checks)
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database
- **CodeQL** - GitHub's semantic code analysis
#### Testing (9 suites)
- Race Detector
- Coverage (75% threshold)
- Memory Leaks
- Integration Tests
- Regression Tests
- Security Edge Cases
- Session Tests
- Token Tests
- CSRF Tests
#### Provider Testing (9 providers)
Tests run in parallel for:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (Go 1.23 & 1.24)
### Quality Gates
All PRs must pass:
- All parallel checks
- 75% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
- All providers tested
- Builds on all platforms
---
## Code Quality
### Pre-Commit Checklist
```bash
# Run before every commit
gofmt -s -w . && \
go mod tidy && \
golangci-lint run && \
go test -race -short ./... && \
echo "Ready to commit!"
```
### Local Validation
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Static analysis
staticcheck ./...
# Security scan
gosec ./...
# Vulnerability check
govulncheck ./...
# Tests with race detector
go test -race -timeout=15m -count=1 ./...
# Coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# View coverage in browser
go tool cover -html=coverage.out
```
### Troubleshooting
**Coverage Below Threshold:**
```bash
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out # See uncovered lines
```
**Race Condition Found:**
```bash
go test -race -v -run=TestName ./...
```
**Linter Errors:**
```bash
golangci-lint run -v
golangci-lint run --fix # Auto-fix some issues
```
**Provider Test Fails:**
```bash
go test -v -run='.*Azure.*' ./...
```
---
## Contributing
### Development Guidelines
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
4. **Documentation**: Update README and configuration files for new options
### Pull Request Template
PRs should include:
- Description of changes
- Type of change (bug fix, feature, breaking change, etc.)
- Related issues
- Provider impact (which providers are affected)
- Testing performed
- Security considerations
- Performance impact
- Breaking changes (if any)
### Checklist
Before submitting:
- [ ] Code follows project style
- [ ] Self-review completed
- [ ] Tests added for new functionality
- [ ] All tests pass locally
- [ ] Documentation updated
- [ ] No new warnings generated
### Code Owners
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
### Dependabot
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
---
## Additional Resources
- [golangci-lint Rules](.golangci.yml)
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
- [Workflow Documentation](.github/workflows/README.md)
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
+580
View File
@@ -0,0 +1,580 @@
# OIDC Provider Configuration Guide
Configuration reference for each supported OIDC provider.
## Table of Contents
- [Provider Support Matrix](#provider-support-matrix)
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [Okta](#okta)
- [Keycloak](#keycloak)
- [AWS Cognito](#aws-cognito)
- [GitLab](#gitlab)
- [GitHub](#github)
- [Generic OIDC](#generic-oidc)
- [Automatic Scope Filtering](#automatic-scope-filtering)
---
## Provider Support Matrix
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com` | Yes |
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
| Generic | Full | Yes | Any OIDC endpoint | Yes |
---
## Google
### Provider URL
```yaml
providerURL: "https://accounts.google.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
spec:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-id.apps.googleusercontent.com"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- email
- profile
allowedUserDomains:
- "your-gsuite-domain.com" # Optional: Workspace restriction
forceHttps: true
enablePkce: true
```
### Google-Specific Features
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
### Google Cloud Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Navigate to APIs & Services > Credentials
4. Create OAuth 2.0 Client ID (Web application)
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
6. Configure OAuth consent screen (must be "Published" for production)
---
## Microsoft Azure AD
### Provider URL
```yaml
# Single tenant
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# Multi-tenant
providerURL: "https://login.microsoftonline.com/common/v2.0"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
allowedRolesAndGroups:
- "App.Users"
- "Admin.Group"
forceHttps: true
```
### With Application ID URI (API Access)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-api
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
audience: "api://your-azure-client-id" # Application ID URI
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
forceHttps: true
```
### Users Without Email
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "user-object-id-1"
- "user-object-id-2"
```
### Azure AD Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to Azure Active Directory > App registrations
3. Create new registration
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
5. Create client secret in Certificates & secrets
6. Configure Token Configuration for group claims
---
## Auth0
### Provider URL
```yaml
providerURL: "https://your-domain.auth0.com"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
postLogoutRedirectUri: "https://your-app.com"
forceHttps: true
enablePkce: true
```
### With Custom API Audience
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-api
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
audience: "https://api.your-domain.com" # API identifier
strictAudienceValidation: true
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
roleClaimName: "https://your-app.com/roles" # Namespaced claim
groupClaimName: "https://your-app.com/groups"
allowedRolesAndGroups:
- admin
- editor
```
### Auth0 Action for Custom Claims
```javascript
exports.onExecutePostLogin = async (event, api) => {
const namespace = 'https://your-app.com/';
if (event.authorization) {
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
api.idToken.setCustomClaim('email', event.user.email);
}
};
```
### Auth0 Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create Regular Web Application
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
5. Enable OIDC Conformant in Advanced Settings
6. Create API in APIs section for custom audiences
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
---
## Okta
### Provider URL
```yaml
providerURL: "https://your-domain.okta.com"
# Or with custom authorization server:
providerURL: "https://your-domain.okta.com/oauth2/default"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-okta
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.okta.com"
clientID: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- groups
- offline_access
allowedRolesAndGroups:
- admin
- "Everyone"
forceHttps: true
enablePkce: true
```
### Okta Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect > Web Application
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
6. Enable Authorization Code and Refresh Token grant types
7. Configure Groups claim in authorization server
---
## Keycloak
### Provider URL
```yaml
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-keycloak
spec:
plugin:
traefikoidc:
providerURL: "https://keycloak.company.com/realms/your-realm"
clientID: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- roles
- groups
- offline_access
allowedRolesAndGroups:
- admin
- editor
forceHttps: true
enablePkce: true
```
### Internal Network Deployment
For private IP addresses (Docker networks, Kubernetes):
```yaml
providerURL: "https://192.168.1.100:8443/realms/your-realm"
allowPrivateIPAddresses: true # Required for private IPs
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select your realm
3. Go to Clients > Create client
4. Set Client Protocol: openid-connect
5. Set Access Type: confidential
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
7. Generate client secret in Credentials tab
8. Configure mappers to add claims to ID Token:
- Email: User Property mapper with "Add to ID token" enabled
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
---
## AWS Cognito
### Provider URL
```yaml
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-cognito
spec:
plugin:
traefikoidc:
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientID: "your-cognito-client-id"
clientSecret: "your-cognito-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- aws.cognito.signin.user.admin
allowedRolesAndGroups:
- admin
- users
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/oauth2/callback`
- Sign out URLs: `https://your-domain.com/oauth2/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
5. Set up groups for role-based access
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerURL: "https://gitlab.com"
# Self-hosted
providerURL: "https://gitlab.your-company.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-gitlab
spec:
plugin:
traefikoidc:
providerURL: "https://gitlab.com"
clientID: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
# Note: GitLab doesn't require offline_access scope
# Refresh tokens are issued automatically with openid
allowedRolesAndGroups:
- developers
- maintainers
forceHttps: true
enablePkce: true
```
### GitLab Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
5. Save and note Application ID and Secret
---
## GitHub
### Provider URL
```yaml
providerURL: "https://github.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oauth-github
spec:
plugin:
traefikoidc:
providerURL: "https://github.com/login/oauth"
clientID: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- user:email
- read:user
allowedUsers:
- "github-username"
forceHttps: true
```
### Limitations
- **OAuth 2.0 only** - Not OpenID Connect
- **No ID tokens** - Only access tokens for API calls
- **No refresh tokens** - Users must re-authenticate on expiry
- **No standard claims** - User info requires API calls
Use GitHub only for API access, not for user authentication with claims.
### GitHub Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
4. Note Client ID and generate Client Secret
---
## Generic OIDC
For any OIDC-compliant provider not listed above.
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-generic
spec:
plugin:
traefikoidc:
providerURL: "https://oidc.your-provider.com"
clientID: "your-client-id"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
forceHttps: true
enablePkce: true
```
### Requirements
- Provider must expose `.well-known/openid-configuration` endpoint
- Must support authorization code flow
- ID tokens must contain required claims (email, sub, etc.)
---
## Automatic Scope Filtering
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
### How It Works
1. Fetches provider's `.well-known/openid-configuration`
2. Extracts `scopes_supported` field
3. Filters requested scopes to only include supported ones
4. Falls back to all requested scopes if provider doesn't declare supported scopes
### Example: Self-Hosted GitLab
Self-hosted GitLab may reject `offline_access` scope:
```yaml
scopes:
- openid
- profile
- email
- offline_access # Will be automatically filtered out if unsupported
```
The middleware will:
1. Read GitLab's discovery document
2. Detect `offline_access` is NOT in `scopes_supported`
3. Filter it out automatically
4. Authentication succeeds
### Logging
```
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
If a provider rejects scopes even after filtering:
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
2. Use `overrideScopes: true` with only supported scopes
3. Review middleware debug logs for filtering decisions
-955
View File
@@ -1,955 +0,0 @@
# Provider-Specific Configuration Guide
This guide covers the configuration requirements and best practices for each supported OIDC provider.
## Table of Contents
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [GitHub](#github)
- [GitLab](#gitlab)
- [AWS Cognito](#aws-cognito)
- [Keycloak](#keycloak)
- [Okta](#okta)
- [Generic OIDC](#generic-oidc)
---
## Google
### Provider URL
```yaml
providerUrl: "https://accounts.google.com"
```
### Required Configuration
```yaml
clientId: "your-google-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Google-Specific Features
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
- **Refresh token support**: Fully supported
- **Domain restrictions**: Can restrict by Google Workspace domains
### Example Configuration
```yaml
# Traefik dynamic configuration
http:
middlewares:
google-oidc:
plugin:
traefik-oidc:
providerUrl: "https://accounts.google.com"
clientId: "123456789-abcdef.apps.googleusercontent.com"
clientSecret: "GOCSPX-your-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
allowedUserDomains: ["example.com", "company.org"]
forceHttps: true
enablePkce: true
```
### Google OAuth Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
---
## Microsoft Azure AD
### Provider URL
```yaml
# For Azure AD (single tenant)
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# For Azure AD (multi-tenant)
providerUrl: "https://login.microsoftonline.com/common/v2.0"
```
### Required Configuration
```yaml
clientId: "your-azure-application-id"
clientSecret: "your-azure-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Azure-Specific Features
- **Response mode**: Automatically adds `response_mode=query`
- **Offline access**: Requires `offline_access` scope for refresh tokens
- **Access token validation**: Supports both JWT and opaque access tokens
- **Tenant isolation**: Can restrict to specific Azure AD tenants
- **Application ID URI**: Supports custom audience for protected APIs
### Example Configuration (Basic)
```yaml
http:
middlewares:
azure-oidc:
plugin:
traefik-oidc:
providerUrl: "https://login.microsoftonline.com/common/v2.0"
clientId: "12345678-1234-1234-1234-123456789abc"
clientSecret: "your-azure-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
forceHttps: true
```
### Azure AD API Configuration (Application ID URI)
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
```yaml
http:
middlewares:
azure-api-oidc:
plugin:
traefik-oidc:
providerUrl: "https://login.microsoftonline.com/common/v2.0"
clientId: "12345678-1234-1234-1234-123456789abc"
clientSecret: "your-azure-client-secret"
# Specify the Application ID URI as audience
audience: "api://12345678-1234-1234-1234-123456789abc"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
forceHttps: true
```
**Important**:
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
- Without the `audience` parameter, access tokens with custom audiences will be rejected
- For ID token validation only (no API access), you can omit the `audience` parameter
### Azure App Registration Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to "Azure Active Directory" > "App registrations"
3. Create new registration
4. Add redirect URI: `https://your-domain.com/auth/callback`
5. Create client secret in "Certificates & secrets"
6. Configure API permissions for required scopes
### Azure AD API Exposure Setup (for custom audiences)
1. In your App Registration, go to "Expose an API"
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
3. Add any custom scopes your API exposes
4. Update the middleware configuration to include the `audience` parameter with this URI
---
## Auth0
### Provider URL
```yaml
providerUrl: "https://your-domain.auth0.com"
```
### Required Configuration
```yaml
clientId: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Auth0-Specific Features
- **Custom domains**: Supports Auth0 custom domains
- **Rules and hooks**: Leverages Auth0's extensibility
- **Social connections**: Works with Auth0's social identity providers
- **Offline access**: Requires `offline_access` scope
- **API audiences**: Supports custom audience for API access tokens
### Example Configuration (Basic)
```yaml
http:
middlewares:
auth0-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.auth0.com"
clientId: "abcdef123456789"
clientSecret: "your-auth0-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedUsers: ["user@example.com", "admin@company.com"]
forceHttps: true
enablePkce: true
```
### Auth0 API Configuration (Custom Audience)
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
```yaml
http:
middlewares:
auth0-api-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.auth0.com"
clientId: "abcdef123456789"
clientSecret: "your-auth0-client-secret"
# Specify the Auth0 API identifier as audience
audience: "https://api.company.com"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
forceHttps: true
enablePkce: true
```
**Important**:
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
- For ID token validation only (no APIs), you can omit the `audience` parameter
### Auth0 Application Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create new application (Regular Web Application)
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
5. Enable OIDC Conformant in Advanced Settings
### Auth0 API Setup (for custom audiences)
1. Go to Auth0 Dashboard → APIs
2. Create a new API or select existing API
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
4. In API Settings → Machine to Machine Applications, authorize your application
5. Configure API permissions/scopes as needed
6. Use the API identifier as the `audience` parameter in your configuration
---
## GitHub
### Provider URL
```yaml
providerUrl: "https://github.com"
```
### Required Configuration
```yaml
clientId: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["read:user", "user:email"]
```
### GitHub-Specific Features
- **Organization membership**: Can restrict by GitHub organization
- **Team membership**: Can restrict by specific teams
- **Limited OIDC**: GitHub has limited OIDC support
- **Email verification**: Requires verified email addresses
### Example Configuration
```yaml
http:
middlewares:
github-oidc:
plugin:
traefik-oidc:
providerUrl: "https://github.com"
clientId: "Iv1.abcdef123456"
clientSecret: "your-github-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["read:user", "user:email"]
allowedUsers: ["octocat", "github-user"]
forceHttps: true
```
### GitHub OAuth App Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
4. Note the Client ID and generate Client Secret
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerUrl: "https://gitlab.com"
# Self-hosted GitLab
providerUrl: "https://gitlab.your-company.com"
```
### Required Configuration
```yaml
clientId: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### GitLab-Specific Features
- **Self-hosted support**: Works with self-hosted GitLab instances
- **Group membership**: Can restrict by GitLab groups
- **Project access**: Can validate project permissions
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
### Example Configuration
```yaml
http:
middlewares:
gitlab-oidc:
plugin:
traefik-oidc:
providerUrl: "https://gitlab.com"
clientId: "abcdef123456789"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
# Note: GitLab doesn't support the offline_access scope.
# Refresh tokens are issued automatically for the openid scope.
allowedRolesAndGroups: ["developers", "maintainers"]
forceHttps: true
enablePkce: true
```
### GitLab Application Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/auth/callback`
5. Save and note the Application ID and Secret
---
## AWS Cognito
### Provider URL
```yaml
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Required Configuration
```yaml
clientId: "your-cognito-app-client-id"
clientSecret: "your-cognito-app-client-secret" # If app client has secret
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Cognito-Specific Features
- **User pools**: Integrates with Cognito User Pools
- **Custom attributes**: Supports custom user attributes
- **Groups**: Can validate Cognito user group membership
- **Regional endpoints**: Requires region-specific URLs
### Example Configuration
```yaml
http:
middlewares:
cognito-oidc:
plugin:
traefik-oidc:
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientId: "1234567890abcdefghij"
clientSecret: "your-cognito-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
allowedRolesAndGroups: ["admin", "users"]
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/auth/callback`
- Sign out URLs: `https://your-domain.com/auth/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
---
## Keycloak
### Provider URL
```yaml
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
```
### Required Configuration
```yaml
clientId: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Keycloak-Specific Features
- **Realm support**: Multi-realm deployments
- **Custom mappers**: Rich claim mapping capabilities
- **Role-based access**: Fine-grained role management
- **Offline access**: Full refresh token support
### Example Configuration
```yaml
http:
middlewares:
keycloak-oidc:
plugin:
traefik-oidc:
providerUrl: "https://keycloak.company.com/realms/employees"
clientId: "traefik-app"
clientSecret: "your-keycloak-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["app-users", "administrators"]
forceHttps: true
enablePkce: true
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select appropriate realm
3. Create new client:
- Client Protocol: openid-connect
- Access Type: confidential
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
4. Configure client scopes and mappers
5. Generate client secret in Credentials tab
---
## Okta
### Provider URL
```yaml
providerUrl: "https://your-domain.okta.com"
```
### Required Configuration
```yaml
clientId: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Okta-Specific Features
- **Custom authorization servers**: Supports custom auth servers
- **Group claims**: Rich group membership information
- **Universal Directory**: Integrates with Okta's user store
- **Offline access**: Requires `offline_access` scope
### Example Configuration
```yaml
http:
middlewares:
okta-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.okta.com"
clientId: "0oa123456789abcdef"
clientSecret: "your-okta-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["Everyone", "Administrators"]
forceHttps: true
enablePkce: true
```
### Okta Application Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect
4. Choose Web Application
5. Configure:
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
- Grant types: Authorization Code, Refresh Token
6. Assign users or groups
---
## Generic OIDC
### Provider URL
```yaml
providerUrl: "https://your-oidc-provider.com"
```
### Required Configuration
```yaml
clientId: "your-client-id"
clientSecret: "your-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Generic Features
- **Standards compliance**: Works with any OIDC-compliant provider
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
- **Flexible scopes**: Supports custom scope requirements
- **Custom claims**: Works with provider-specific claims
### Example Configuration
```yaml
http:
middlewares:
generic-oidc:
plugin:
traefik-oidc:
providerUrl: "https://oidc.your-provider.com"
clientId: "your-client-id"
clientSecret: "your-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
forceHttps: true
enablePkce: true
```
---
## Automatic Scope Filtering
### Overview
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
### How It Works
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
### Example Scenarios
#### Self-Hosted GitLab
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
```
The requested scope is invalid, unknown, or malformed.
```
**Solution**: The middleware automatically detects this by:
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
2. Observing that `offline_access` is NOT in the `scopes_supported` list
3. Filtering out `offline_access` from the request
4. Authentication succeeds
**Configuration**:
```yaml
http:
middlewares:
gitlab-oidc:
plugin:
traefik-oidc:
providerUrl: "https://gitlab.example.com"
clientId: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://app.example.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
# Even though offline_access is listed, it will be automatically
# filtered out if GitLab doesn't support it
```
#### Auth0 or Keycloak
These providers typically support `offline_access` and it will be included:
```yaml
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
# Result: All requested scopes are sent
```
### Benefits
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
2. **No Manual Configuration**: No need to know which scopes each provider supports
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
5. **Backward Compatible**: Existing configurations continue to work
### Logging
The middleware provides detailed logging for scope filtering:
```
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
**Issue**: Provider rejects scope even after filtering
**Possible Causes**:
1. Provider's discovery document is outdated
2. Provider doesn't properly implement `scopes_supported`
3. Custom authorization server with non-standard behavior
**Solutions**:
1. Use `overrideScopes: true` and explicitly list only supported scopes
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
3. Review middleware debug logs for filtering decisions
---
## Common Configuration Options
### Audience Configuration
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
```yaml
# Optional: Custom audience for JWT validation
# If not set, defaults to clientID for backward compatibility
audience: "https://api.example.com" # Auth0 API identifier
# OR
audience: "api://12345-guid" # Azure AD Application ID URI
```
**When to use**:
- **Auth0**: When using Auth0 APIs with custom audience parameters
- **Azure AD**: When exposing your app as an API with Application ID URI
- **Keycloak**: When using audience-restricted tokens
- **Okta**: When using custom authorization servers with API audiences
**When to omit**:
- For standard ID token validation (default behavior)
- When the provider sets `aud` claim to your `clientID`
- For backward compatibility with existing configurations
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
### Security Settings
```yaml
# Force HTTPS (recommended for production)
forceHttps: true
# Enable PKCE (recommended for security)
enablePkce: true
# Session encryption key (32+ characters)
sessionEncryptionKey: "your-very-long-encryption-key-here"
```
### Access Control
```yaml
# Restrict by email addresses
allowedUsers: ["user1@example.com", "user2@example.com"]
# Restrict by email domains
allowedUserDomains: ["company.com", "partner.org"]
# Restrict by roles/groups (provider-specific)
allowedRolesAndGroups: ["admin", "users", "developers"]
```
### URLs and Endpoints
```yaml
# OAuth callback URL (must match provider config)
callbackUrl: "https://your-domain.com/auth/callback"
# Logout endpoint
logoutUrl: "https://your-domain.com/auth/logout"
# Post-logout redirect (optional)
postLogoutRedirectUri: "https://your-domain.com"
# URLs to exclude from authentication
excludedUrls: ["/health", "/metrics", "/public"]
```
### Advanced Settings
```yaml
# Override default scopes
overrideScopes: true
scopes: ["openid", "custom_scope"]
# Rate limiting (requests per second)
rateLimit: 10
# Token refresh grace period (seconds)
refreshGracePeriodSeconds: 60
# Cookie domain (for subdomain sharing)
cookieDomain: ".example.com"
# Custom headers to inject
headers:
- name: "X-User-Email"
value: "{{.Claims.email}}"
- name: "X-User-Name"
value: "{{.Claims.name}}"
```
---
## Troubleshooting
### Common Issues
1. **Invalid redirect URI**
- Ensure callback URL exactly matches provider configuration
- Check for HTTP vs HTTPS mismatches
2. **Scope errors**
- Verify required scopes are configured in provider
- Some providers require specific scopes for refresh tokens
3. **Token validation failures**
- Check provider URL format and accessibility
- Verify `.well-known/openid-configuration` endpoint is reachable
4. **Session issues**
- Ensure session encryption key is properly configured
- Check cookie domain settings for subdomain scenarios
### Debug Mode
Enable debug logging to troubleshoot configuration issues:
```yaml
logLevel: "debug"
```
This will provide detailed logs of the authentication flow and help identify configuration problems.
---
## Security Headers Configuration
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
### Default Security Headers
By default, the plugin applies these security headers:
- `X-Frame-Options: DENY` - Prevents clickjacking
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
### Security Profiles
Choose from predefined security profiles or create custom configurations:
#### Default Profile (Recommended)
```yaml
securityHeaders:
enabled: true
profile: "default"
```
#### Strict Profile (Maximum Security)
```yaml
securityHeaders:
enabled: true
profile: "strict"
# Additional strict CSP and cross-origin policies
```
#### Development Profile (Local Development)
```yaml
securityHeaders:
enabled: true
profile: "development"
# Relaxed policies for local development
```
#### API Profile (API Endpoints)
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins: ["https://your-frontend.com"]
```
### Custom Security Configuration
For complete control, use the custom profile:
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
# HSTS Configuration
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000 # 1 year
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and content protection
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# Permissions policy (feature policy)
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
# Cross-origin policies
crossOriginEmbedderPolicy: "require-corp"
crossOriginOpenerPolicy: "same-origin"
crossOriginResourcePolicy: "same-origin"
# CORS configuration
corsEnabled: true
corsAllowedOrigins:
- "https://app.example.com"
- "https://*.api.example.com"
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
corsAllowCredentials: true
corsMaxAge: 86400 # 24 hours
# Custom headers
customHeaders:
X-Custom-Header: "custom-value"
X-API-Version: "v1"
# Server identification
disableServerHeader: true
disablePoweredByHeader: true
```
### Complete Example with Security Headers
Here's a complete configuration example for Google OIDC with custom security headers:
```yaml
# Traefik dynamic configuration
http:
middlewares:
secure-google-oidc:
plugin:
traefik-oidc:
# OIDC Configuration
providerUrl: "https://accounts.google.com"
clientId: "123456789-abcdef.apps.googleusercontent.com"
clientSecret: "GOCSPX-your-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
sessionEncryptionKey: "your-32-character-encryption-key-here"
# Domain restrictions
allowedUserDomains: ["your-company.com"]
# Security Headers
securityHeaders:
enabled: true
profile: "strict"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.your-domain.com"
corsAllowCredentials: true
customHeaders:
X-Company: "YourCompany"
X-Environment: "production"
routers:
secure-app:
rule: "Host(`your-domain.com`)"
middlewares:
- secure-google-oidc
service: your-app-service
tls:
certResolver: letsencrypt
```
### CORS Configuration Details
For applications with frontend-backend separation, configure CORS properly:
#### Simple CORS (Single Origin)
```yaml
securityHeaders:
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowCredentials: true
```
#### Wildcard Subdomains
```yaml
securityHeaders:
corsEnabled: true
corsAllowedOrigins: ["https://*.example.com"]
corsAllowCredentials: true
```
#### Development with Multiple Ports
```yaml
securityHeaders:
profile: "development"
corsEnabled: true
corsAllowedOrigins:
- "http://localhost:*"
- "http://127.0.0.1:*"
```
### Security Best Practices
1. **Always use HTTPS in production**
- Set `forceHttps: true`
- Configure proper TLS certificates
2. **Implement proper CSP**
- Start with strict policy
- Add exceptions only when necessary
- Test thoroughly
3. **Configure CORS restrictively**
- Only allow necessary origins
- Use specific domains instead of wildcards when possible
4. **Enable HSTS**
- Use long max-age values (1 year minimum)
- Include subdomains when appropriate
5. **Monitor security headers**
- Use browser developer tools to verify headers
- Test with security scanning tools
- Regularly review and update policies
### Testing Security Headers
Use browser developer tools or online tools to verify your security headers:
1. **Browser DevTools**: Check Network tab → Response Headers
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
3. **Command line**: Use `curl -I https://your-domain.com`
Example verification:
```bash
curl -I https://your-domain.com
# Should show security headers in response
```
+546
View File
@@ -0,0 +1,546 @@
# Redis Cache for Distributed Deployments
Redis cache support for multi-replica Traefik deployments with shared state.
## Table of Contents
- [Overview](#overview)
- [Why Use Redis Cache?](#why-use-redis-cache)
- [Configuration](#configuration)
- [Cache Modes](#cache-modes)
- [Deployment Examples](#deployment-examples)
- [Performance Tuning](#performance-tuning)
- [Monitoring](#monitoring)
- [Troubleshooting](#troubleshooting)
- [Migration Guide](#migration-guide)
---
## Overview
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
### Key Features
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
- **Shared Session Management**: Consistent user sessions across replicas
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
- **Health Checking**: Continuous monitoring of Redis connectivity
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
### Architecture
```
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
│ │ │
└────────────────────┼────────────────────┘
┌──────▼──────┐
│ Redis │
│ (Shared │
│ Cache) │
└─────────────┘
```
---
## Why Use Redis Cache?
### The Problem
When running multiple Traefik instances without shared cache:
1. **False Positive Replay Detection**
- User authenticates → Token stored in Instance A's JTI cache
- Next request → Load balancer routes to Instance B
- Instance B doesn't have the JTI → Falsely detects replay attack
2. **Session Inconsistency**
- User session created on Instance A
- Subsequent request routed to Instance B
- Instance B has no knowledge of the session
3. **Token Metadata Fragmentation**
- Token refresh happens on Instance A
- Other instances continue using old tokens
### The Solution
Redis provides centralized cache that all instances share, ensuring:
- **Consistent Authentication**: All instances share authentication state
- **True Replay Detection**: JTI cache shared across all instances
- **Seamless Scaling**: Add/remove instances without affecting sessions
- **High Availability**: Circuit breaker with automatic fallback
---
## Configuration
### Basic Configuration
```yaml
redis:
enabled: true
address: "redis:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
```
### All Configuration Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enabled` | bool | `false` | Enable Redis caching |
| `address` | string | - | Redis server address (`host:port`) |
| `password` | string | - | Redis password (optional) |
| `db` | int | `0` | Redis database number (0-15) |
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
| `poolSize` | int | `10` | Connection pool size |
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
| `readTimeout` | int | `3` | Read timeout (seconds) |
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables are used:
```bash
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=your-password
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=10
REDIS_CONNECT_TIMEOUT=5
REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
```
---
## Cache Modes
### Memory Mode (Default without Redis)
```yaml
redis:
cacheMode: "memory"
```
- Uses only in-memory cache
- Suitable for single-instance deployments
- No Redis dependency
- Fastest performance
### Redis Mode
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
- All operations go directly to Redis
- Ensures consistency across replicas
- Slightly higher latency
### Hybrid Mode (Recommended)
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
Two-tier caching strategy:
```
┌─────────────────────────────────────────┐
│ Client Request │
└────────────────┬────────────────────────┘
┌────────────────┐
│ Local Cache │ ← L1 Cache (Fast)
│ (Memory) │
└────────┬───────┘
│ Miss
┌────────────────┐
│ Remote Cache │ ← L2 Cache (Shared)
│ (Redis) │
└────────────────┘
```
**Read Path:**
1. Check local memory cache (L1)
2. On miss, check Redis (L2)
3. On hit in Redis, populate L1
4. Return value
**Write Path:**
1. Write to Redis (L2) for durability
2. Write to local cache (L1) for speed
### Performance Comparison
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|-----------|------------|------------|-------------|
| Read (p50) | 0.1ms | 2ms | 0.2ms |
| Read (p99) | 0.5ms | 10ms | 5ms |
| Write (p50) | 0.2ms | 3ms | 3ms |
| Throughput | 100k/s | 20k/s | 80k/s |
---
## Deployment Examples
### Docker Compose
```yaml
version: '3.8'
services:
redis:
image: redis:7-alpine
command: redis-server --requirepass ${REDIS_PASSWORD}
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 30s
timeout: 3s
retries: 3
traefik:
image: traefik:v3.2
deploy:
replicas: 3
labels:
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
depends_on:
redis:
condition: service_healthy
volumes:
redis-data:
```
### Kubernetes
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
circuitBreakerThreshold: 5
```
### AWS ElastiCache
```yaml
redis:
enabled: true
address: "your-cache.abc123.cache.amazonaws.com:6379"
cacheMode: "hybrid"
enableTLS: true
password: "your-elasticache-auth-token"
```
---
## Performance Tuning
### Connection Pool Sizing
```yaml
redis:
poolSize: 20 # Formula: 2 * CPU cores * replicas
# For 4 cores, 3 replicas: poolSize = 24
```
### TTL Strategy
The plugin automatically sets TTLs based on token lifetimes:
- **JTI Cache**: Matches token lifetime (typically 1 hour)
- **Session**: Matches `sessionMaxAge` configuration
- **Token Metadata**: 5 minutes (short-lived)
### Redis Server Configuration
```bash
# Recommended Redis settings for cache
maxmemory 512mb
maxmemory-policy allkeys-lru # Evict least recently used
# For cache data, disable persistence for better performance
save ""
appendonly no
```
### Hybrid Mode Tuning
```yaml
redis:
cacheMode: "hybrid"
hybridL1Size: 500 # Max items in local cache
hybridL1MemoryMB: 10 # Max memory for local cache
```
---
## Monitoring
### Key Metrics
- **Cache hit rate** (target: >90% for hybrid mode)
- **Redis latency** (target: <10ms p99)
- **Circuit breaker state**
- **Connection pool utilization
### Redis Commands for Monitoring
```bash
# Monitor commands in real-time
redis-cli MONITOR
# Check slow queries
redis-cli SLOWLOG GET 10
# Memory usage
redis-cli INFO memory
# Key statistics
redis-cli DBSIZE
# List keys with prefix
redis-cli --scan --pattern "traefikoidc:*"
# Check key TTL
redis-cli TTL "traefikoidc:session:abc123"
```
### Health Check Endpoint
The plugin provides health information including:
```json
{
"status": "healthy",
"cache": {
"mode": "hybrid",
"redis": {
"connected": true,
"latency": "2ms"
},
"circuit_breaker": {
"state": "closed",
"failures": 0
}
}
}
```
---
## Troubleshooting
### Connection Refused
**Symptoms:** `dial tcp: connection refused`
**Solutions:**
1. Verify Redis is running: `redis-cli ping`
2. Check network connectivity: `telnet redis-host 6379`
3. Verify address configuration
### Authentication Failure
**Symptoms:** `NOAUTH Authentication required`
**Solutions:**
1. Set Redis password in configuration
2. Verify password is correct
### Circuit Breaker Open
**Symptoms:** `Circuit breaker is open`, falling back to memory
**Solutions:**
1. Check Redis health: `redis-cli INFO server`
2. Review network latency: `redis-cli --latency`
3. Adjust circuit breaker thresholds if needed
### High Memory Usage
**Symptoms:** Redis memory constantly growing, OOM errors
**Solutions:**
1. Configure eviction policy:
```bash
CONFIG SET maxmemory 512mb
CONFIG SET maxmemory-policy allkeys-lru
```
2. Review key count: `redis-cli DBSIZE`
3. Check for large keys: `redis-cli --bigkeys`
### Inconsistent Cache State
**Symptoms:** Different responses from different replicas
**Solutions:**
1. Verify all instances use the same Redis address
2. Check cache mode consistency across instances
3. Verify time synchronization on all hosts
---
## Migration Guide
### From Memory-Only to Redis
#### Phase 1: Preparation
1. Deploy Redis infrastructure
2. Test Redis connectivity
3. Configure monitoring
#### Phase 2: Gradual Rollout
1. Enable Redis on one instance:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
2. Monitor for errors
3. Gradually enable on more instances
#### Phase 3: Full Migration
1. Enable Redis on all instances
2. Remove `disableReplayDetection: true` if set
3. Monitor for issues
### Rollback Plan
If issues occur:
1. Set `redis.enabled: false`
2. Plugin falls back to memory cache automatically
3. Investigate and resolve issues
### Migration Checklist
- [ ] Redis deployed and accessible
- [ ] Redis password configured
- [ ] Network connectivity verified
- [ ] Monitoring configured
- [ ] Backup plan prepared
- [ ] Test environment validated
- [ ] Gradual rollout planned
---
## Best Practices
### Security
- Always use Redis password authentication
- Enable TLS for production deployments
- Use network segmentation (private subnets)
- Rotate Redis passwords regularly
### High Availability
- Use Redis Sentinel or Cluster for HA
- Configure appropriate circuit breaker thresholds
- Implement proper health checks
- Use connection pooling
### Performance
- Use hybrid cache mode for best performance
- Monitor cache hit rates
- Size Redis memory appropriately
- Disable persistence for cache-only usage
### Operations
- Implement comprehensive monitoring
- Set up alerting for circuit breaker state
- Document Redis configuration
- Test failover scenarios
---
## FAQ
### Is Redis required?
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
### What happens if Redis goes down?
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
### Which cache mode should I use?
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
### How much memory does Redis need?
Depends on active sessions and token sizes:
- Small (1-1000 users): 128MB
- Medium (1000-10000 users): 256-512MB
- Large (10000+ users): 1GB+
### Can I use managed Redis services?
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
### Is data encrypted in Redis?
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
-1125
View File
File diff suppressed because it is too large Load Diff
-413
View File
@@ -1,413 +0,0 @@
# Redis Cache Backend Test Suite
## Overview
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
## Test Structure
### Directory Organization
```
internal/cache/
├── backend/
│ ├── interface.go # CacheBackend interface definition
│ ├── interface_test.go # Contract tests for all backends
│ ├── memory.go # In-memory backend implementation
│ ├── memory_test.go # Memory backend unit tests
│ ├── redis.go # Redis backend implementation
│ ├── redis_test.go # Redis backend unit tests
│ ├── errors.go # Error definitions
│ └── test_helpers_test.go # Test infrastructure and helpers
└── resilience/
├── circuit_breaker.go # Circuit breaker implementation
├── circuit_breaker_test.go # Circuit breaker tests
├── health_check.go # Health checker implementation
└── health_check_test.go # Health check tests
redis_integration_test.go # End-to-end integration tests
```
## Test Categories
### 1. Interface Contract Tests (`interface_test.go`)
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
**Test Cases:**
- `TestCacheBackendContract` - Runs all contract tests against each backend type
- `testBasicSetGet` - Verifies basic set/get operations
- `testGetNonExistent` - Tests behavior for non-existent keys
- `testUpdateExisting` - Validates updating existing keys
- `testDelete` - Tests delete operations
- `testDeleteNonExistent` - Delete non-existent keys
- `testExists` - Key existence checking
- `testTTLExpiration` - TTL and expiration behavior
- `testClear` - Clear all keys operation
- `testPing` - Health check functionality
- `testStats` - Statistics tracking
- `testConcurrentAccess` - Thread safety with 10+ goroutines
- `testLargeValues` - Handling of 1MB+ values
- `testEmptyValues` - Empty byte array handling
- `testSpecialCharactersInKeys` - Special characters in key names
**Coverage:** ~95% of interface methods
### 2. Memory Backend Tests (`memory_test.go`)
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
**Test Cases:**
#### Basic Operations (6 tests)
- `TestMemoryBackend_BasicOperations` - CRUD operations
- SetAndGet
- GetNonExistent
- Delete
- DeleteNonExistent
- Exists
- Clear
#### TTL and Expiration (3 tests)
- `TestMemoryBackend_TTLExpiration`
- ShortTTL (100ms)
- TTLDecrement over time
- CleanupExpiredItems
#### LRU Eviction (2 tests)
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
#### Concurrency (1 test)
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
#### Edge Cases (6 tests)
- `TestMemoryBackend_UpdateExisting` - Overwriting values
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
- `TestMemoryBackend_LargeValues` - 1MB values
- `TestMemoryBackend_Close` - Proper cleanup
- `TestMemoryBackend_Ping` - Health checks
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
**Coverage:** ~92% of memory backend code
### 3. Redis Backend Tests (`redis_test.go`)
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
**Test Cases:**
#### Basic Operations (4 tests)
- `TestRedisBackend_BasicOperations`
- SetAndGet
- GetNonExistent
- Delete
- Exists
#### Redis-Specific Features (6 tests)
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
- `TestRedisBackend_Clear` - Bulk delete with SCAN
- `TestRedisBackend_NoPrefix` - Operation without prefix
#### Error Handling (2 tests)
- `TestRedisBackend_ConnectionFailure` - Connection errors
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
#### Concurrency (1 test)
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
#### Advanced Features (3 tests)
- `TestRedisBackend_PipelineOperations`
- SetMany (batch writes)
- GetMany (batch reads)
- GetManyWithNonExistent
#### Edge Cases (5 tests)
- `TestRedisBackend_Stats` - Statistics tracking
- `TestRedisBackend_Ping` - Connection health
- `TestRedisBackend_Close` - Resource cleanup
- `TestRedisBackend_UpdateExisting` - Overwrite handling
- `TestRedisBackend_LargeValues` - 1MB values
- `TestRedisBackend_EmptyValues` - Empty arrays
**Coverage:** ~88% of Redis backend code
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
- All basic Redis commands
- TTL and expiration
- Time manipulation (FastForward)
- Error simulation
- No external Redis server required
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
**Test Cases:**
#### State Transitions (5 tests)
- `TestCircuitBreaker_StateTransitions`
- Initial state (Closed)
- Closed → Open (after max failures)
- Open → HalfOpen (after timeout)
- HalfOpen → Closed (after successful requests)
- HalfOpen → Open (on failure)
#### Behavior Tests (5 tests)
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
- `TestCircuitBreaker_Stats` - Statistics tracking
#### Advanced Tests (7 tests)
- `TestCircuitBreaker_Reset` - Manual reset
- `TestCircuitBreaker_StateChangeCallback` - Notifications
- `TestCircuitBreaker_IsAvailable` - Availability check
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
- `TestCircuitBreaker_DefaultConfig` - Default configuration
- `TestCircuitBreaker_StateString` - String representation
**Benchmarks:**
- `BenchmarkCircuitBreaker_Execute` - Successful operations
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
**Coverage:** ~95% of circuit breaker code
### 5. Health Check Tests (`health_check_test.go`)
**Purpose:** Validate periodic health checking and status management.
**Test Cases:**
#### Status Transitions (4 tests)
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
- `TestHealthChecker_InitialState` - Default healthy state
- `TestHealthChecker_ForceCheck` - Manual health check trigger
- `TestHealthChecker_StatusChangeCallback` - Change notifications
#### Behavior Tests (6 tests)
- `TestHealthChecker_Stats` - Statistics tracking
- `TestHealthChecker_Timeout` - Check timeout handling
- `TestHealthChecker_ConcurrentAccess` - Thread safety
- `TestHealthChecker_StopAndStart` - Lifecycle management
- `TestHealthChecker_DegradedState` - Degraded status detection
- `TestHealthChecker_DefaultConfig` - Default settings
#### Advanced Tests (2 tests)
- `TestHealthChecker_StatusString` - String representation
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
**Benchmarks:**
- `BenchmarkHealthChecker_ForceCheck` - Check performance
- `BenchmarkHealthChecker_Status` - Status read performance
**Coverage:** ~90% of health checker code
### 6. Integration Tests (`redis_integration_test.go`)
**Purpose:** End-to-end testing of real-world scenarios.
**Test Cases:**
#### Multi-Instance Tests (3 tests)
- `TestRedisIntegration_MultipleInstances`
- ShareTokenBlacklist - JTI sharing across Traefik replicas
- ShareTokenCache - Token cache sharing
- ShareMetadataCache - Provider metadata sharing
#### Replay Detection (2 tests)
- `TestRedisIntegration_JTIReplayDetection`
- PreventReplayAcrossInstances - Block used JTIs
- ConcurrentJTIChecks - Race condition handling
#### Resilience (1 test)
- `TestRedisIntegration_Failover`
- RedisTemporaryFailure - Recovery from temporary failures
#### Performance (1 test)
- `TestRedisIntegration_HighLoad`
- HighConcurrency - 50 goroutines × 100 operations
#### Consistency (2 tests)
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
## Test Helpers and Infrastructure
### Test Helpers (`test_helpers_test.go`)
**Utilities:**
- `TestLogger` - Logging for tests
- `MiniredisServer` - Miniredis setup/teardown
- `TestConfig` - Default test configurations
- `GenerateTestData` - Test data generation
- `GenerateLargeValue` - Large value creation
- `AssertCacheStats` - Statistics validation
- `WaitForCondition` - Async condition waiting
- `AssertEventuallyExpires` - TTL expiration verification
## Running the Tests
### Run All Tests
```bash
go test ./internal/cache/backend/... -v
go test ./internal/cache/resilience/... -v
go test -run TestRedisIntegration -v
```
### Run Specific Test Suites
```bash
# Memory backend only
go test ./internal/cache/backend -run TestMemoryBackend -v
# Redis backend only
go test ./internal/cache/backend -run TestRedisBackend -v
# Circuit breaker only
go test ./internal/cache/resilience -run TestCircuitBreaker -v
# Integration tests only
go test -run TestRedisIntegration -v
```
### Run with Coverage
```bash
go test ./internal/cache/backend/... -coverprofile=coverage.out
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
go tool cover -html=coverage.out
```
### Run Benchmarks
```bash
go test ./internal/cache/backend -bench=. -benchmem
go test ./internal/cache/resilience -bench=. -benchmem
```
### Run with Race Detector
```bash
go test ./internal/cache/... -race -v
```
## Test Patterns Used
### 1. Table-Driven Tests
Used for testing multiple scenarios with similar structure.
### 2. Subtests (t.Run)
Organized test cases into logical groups with clear names.
### 3. Parallel Tests
Tests marked with `t.Parallel()` for faster execution.
### 4. Test Fixtures
Reusable setup functions for common test data.
### 5. Mocking
- `miniredis` for Redis operations
- Mock functions for callbacks and health checks
### 6. Assertion Helpers
Using `testify/assert` and `testify/require` for clear assertions.
## Test Coverage Summary
| Component | Coverage | Tests | Lines of Code |
|-----------|----------|-------|---------------|
| Interface Contract | 95% | 14 | ~200 |
| Memory Backend | 92% | 18 | ~350 |
| Redis Backend | 88% | 21 | ~400 |
| Circuit Breaker | 95% | 17 | ~250 |
| Health Checker | 90% | 12 | ~200 |
| Integration Tests | 80% | 9 | ~300 |
| **Total** | **90%** | **91** | **~1,700** |
## Edge Cases Tested
1. **Empty values** - Zero-length byte arrays
2. **Large values** - 1MB+ data
3. **Special characters** - Keys with :, /, -, _, ., |
4. **Concurrent access** - 10-50 goroutines
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
6. **Connection failures** - Network errors, timeouts
7. **Redis errors** - Simulated Redis failures
8. **Memory limits** - Eviction under memory pressure
9. **Race conditions** - Concurrent JTI checks
10. **State transitions** - All circuit breaker and health check states
## Performance Benchmarks
Benchmarks included for:
- Cache operations (Set, Get, Delete)
- Circuit breaker execution
- Health check operations
- Concurrent access patterns
- Large datasets (10,000+ items)
## Dependencies
### Testing Libraries
- `github.com/stretchr/testify` - Assertions and test utilities
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
- `github.com/redis/go-redis/v9` - Redis client
### Why Miniredis?
- **No external dependencies** - No Redis server required
- **Fast** - In-memory, perfect for unit tests
- **Full Redis API** - Supports all operations we need
- **Time manipulation** - FastForward for TTL testing
- **Error simulation** - Test failure scenarios
## Future Enhancements
### Planned Tests
1. Hybrid backend tests (L1/L2 cache)
2. Network partition scenarios
3. Redis cluster support
4. Persistence and recovery tests
5. Metrics and monitoring integration
### Test Infrastructure Improvements
1. Test containers for real Redis integration
2. Performance regression tracking
3. Chaos engineering tests
4. Load testing framework
## Continuous Integration
### Recommended CI Configuration
```yaml
test:
script:
- go test ./internal/cache/... -race -cover -v
- go test -run TestRedisIntegration -v
- go test ./internal/cache/... -bench=. -benchmem
```
## Maintenance Guidelines
1. **Add tests for new features** - Maintain >85% coverage
2. **Update contract tests** - When interface changes
3. **Test edge cases** - Always test error paths
4. **Document test purpose** - Clear comments explaining what each test validates
5. **Keep tests fast** - Use t.Parallel() where possible
6. **Mock external dependencies** - Use miniredis, not real Redis
## Conclusion
This comprehensive test suite provides:
- **High confidence** in cache backend correctness
- **Fast feedback** - Tests run in seconds
- **Good coverage** - 90% overall
- **Clear documentation** - Each test is well-documented
- **Maintainability** - Clear structure and patterns
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
+390
View File
@@ -0,0 +1,390 @@
# Testing Guide
Comprehensive testing infrastructure for traefikoidc.
## Overview
| Metric | Value |
|--------|-------|
| Test files | 99 |
| Lines of test code | ~65,500 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
## Running Tests
```bash
# Run all tests
go test ./...
# Run with race detection
go test -race ./...
# Run with coverage
go test -cover ./...
# Run specific test suite
go test -v -run "TokenValidationSuite" .
# Run edge case tests
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
```
## Test Infrastructure
### Directory Structure
```
internal/testutil/
├── compat.go # Re-exports for main package access
├── mocks/
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
│ ├── session.go # SessionManager, SessionData
│ ├── cache.go # Cache, TokenCache, Blacklist
│ └── interfaces_test.go # Mock verification tests
├── fixtures/
│ └── tokens.go # JWT token generation fixtures
└── servers/
├── oidc.go # Mock OIDC server factory
└── oidc_test.go # Server tests
```
### Test Suites
| Suite | File | Description |
|-------|------|-------------|
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
## Mock Types
The project provides two mocking patterns:
### State-Based Mocks (Basic)
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
| Mock | Interface | Description |
|------|-----------|-------------|
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
**Usage:**
```go
mock := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tOidc := &TraefikOidc{
jwkCache: mock,
// ...
}
```
### Enhanced State-Based Mocks (with Call Tracking)
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
| Mock | Interface | Description |
|------|-----------|-------------|
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
**Usage:**
```go
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
}
// Make calls
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
// Verify calls were made
mock.AssertGetJWKSCalled(t)
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
mock.AssertGetJWKSCallCount(t, 1)
// Access call details
s.Equal(1, mock.GetJWKSCallCount())
```
**Features:**
- Track all calls with parameters and timestamps
- Built-in assertion helpers using testify
- Thread-safe for concurrent tests
- `Reset()` method to clear state between tests
- `LastCall()` to inspect most recent call
### Testify-Based Mocks
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
| Mock | Interface | Description |
|------|-----------|-------------|
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
**Usage:**
```go
mock := &TestifyJWKCache{}
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
// After test
mock.AssertExpectations(t)
```
### Testutil Package Mocks
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
```go
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
mock := testutil.NewJWKCacheMock()
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
```
### Choosing the Right Mock
| Use Case | Recommended Mock |
|----------|-----------------|
| Simple return values only | Basic state-based (`MockJWKCache`) |
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
| Verify call order/sequence | Testify-based |
| HTTP endpoint simulation | `MockOAuthProvider` |
| New testify suite tests | Enhanced or Testify-based |
**Decision Guide:**
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
## Token Fixtures
The `testutil.TokenFixture` generates JWT tokens for testing:
```go
fixture, err := testutil.NewTokenFixture()
// Valid token with default claims
token, _ := fixture.ValidToken(nil)
// Token with custom claims
token, _ := fixture.ValidToken(map[string]interface{}{
"email": "test@example.com",
"roles": []string{"admin"},
})
// Expired token
token, _ := fixture.ExpiredToken()
// Token with specific roles/groups
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
token, _ := fixture.TokenWithGroups([]string{"developers"})
// Token with clock skew
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
// Token missing specific claims
token, _ := fixture.TokenMissingClaim("email", "sub")
// Malformed token
token := fixture.MalformedToken() // "not.a.valid.jwt"
// Get JWKS for verification
jwks := fixture.GetJWKS()
```
## Mock OIDC Server
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
```go
// Default configuration
server := testutil.NewOIDCServer(nil)
defer server.Close()
// Custom configuration
config := testutil.DefaultServerConfig()
config.Issuer = "https://custom-issuer.com"
config.TokenError = &testutil.OIDCError{
Error: "invalid_grant",
Description: "Authorization code expired",
}
server := testutil.NewOIDCServer(config)
// Provider-specific configurations
googleConfig := testutil.GoogleServerConfig()
azureConfig := testutil.AzureServerConfig()
auth0Config := testutil.Auth0ServerConfig()
keycloakConfig := testutil.KeycloakServerConfig()
// Behavior configurations
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
```
### Server Endpoints
| Endpoint | Description |
|----------|-------------|
| `/.well-known/openid-configuration` | OIDC discovery document |
| `/authorize` | Authorization endpoint |
| `/token` | Token exchange endpoint |
| `/jwks` | JSON Web Key Set |
| `/userinfo` | User information endpoint |
| `/introspect` | Token introspection |
| `/revoke` | Token revocation |
| `/logout` | End session endpoint |
### Request Tracking
```go
server := testutil.NewOIDCServer(nil)
// Make requests...
count := server.GetRequestCount()
requests := server.GetRequests()
server.Reset() // Clear tracking
```
## Writing Test Suites
### Basic Suite Structure
```go
type MyTestSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *MyTestSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *MyTestSuite) SetupTest() {
// Per-test setup
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
// ...
}
}
func (s *MyTestSuite) TearDownTest() {
// Per-test cleanup
}
func (s *MyTestSuite) TestSomething() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err)
}
func TestMyTestSuite(t *testing.T) {
suite.Run(t, new(MyTestSuite))
}
```
### Table-Driven Tests
```go
func (s *MyTestSuite) TestClockSkewEdgeCases() {
testCases := []struct {
name string
skew time.Duration
shouldPass bool
}{
{"valid_token", 5 * time.Minute, true},
{"expired_within_tolerance", -1 * time.Minute, true},
{"expired_beyond_tolerance", -10 * time.Minute, false},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
token, err := s.fixture.TokenWithSkew(tc.skew)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
if tc.shouldPass {
s.NoError(err)
} else {
s.Error(err)
}
})
}
}
```
## Test Categories
### Happy Path Tests
Test the expected successful scenarios:
- Valid token verification
- Successful token exchange
- Session creation and retrieval
- Cache operations
### Error Case Tests
Test failure scenarios:
- Expired tokens
- Invalid signatures
- Wrong issuer/audience
- Network failures
- Rate limiting
### Edge Case Tests
Test boundary conditions:
- Clock skew tolerance boundaries
- Unicode/emoji in claims
- Very large claim values
- Concurrent access
- Special characters in URLs
## Best Practices
1. **Use fixtures for token generation** - Don't manually construct JWTs
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
3. **Always run with `-race`** - Catch concurrency issues early
4. **Use testify assertions** - Better error messages and cleaner code
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
6. **Test edge cases systematically** - Use table-driven tests
-308
View File
@@ -1,308 +0,0 @@
# Test Execution Guide
This guide explains how to run tests efficiently with the new test categorization and optimization system.
## Quick Start
### Fast Development Testing (Default - Target: < 30 seconds)
```bash
# Run quick smoke tests only
go test ./...
# Or explicitly run in short mode
go test ./... -short
```
### Extended Testing (Target: 2-5 minutes)
```bash
# Enable extended tests with more iterations and concurrency
RUN_EXTENDED_TESTS=1 go test ./...
# Or use the flag equivalent (if using test runner that supports it)
go test ./... -extended
```
### Long-Running Performance Tests (Target: 5-15 minutes)
```bash
# Enable comprehensive performance and stress tests
RUN_LONG_TESTS=1 go test ./...
```
### Full Stress Testing (Target: 10-30 minutes)
```bash
# Enable all stress tests with maximum parameters
RUN_STRESS_TESTS=1 go test ./...
```
## Test Categories
### 1. Quick Tests (Default)
- **Purpose**: Fast feedback during development
- **Duration**: < 30 seconds total
- **Features**:
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Minimal concurrency
- Essential memory leak checks
**Configuration**:
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Cache Size: 50
- Timeout: 10 seconds
### 2. Extended Tests
- **Purpose**: Comprehensive testing before commits
- **Duration**: 2-5 minutes
- **Features**:
- Increased test coverage
- More iterations (5-10)
- Medium concurrency tests
- Enhanced memory leak detection
**Configuration**:
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Cache Size: 200
- Timeout: 30 seconds
### 3. Long Tests
- **Purpose**: Performance validation and stress testing
- **Duration**: 5-15 minutes
- **Features**:
- High iteration counts (50-100)
- High concurrency scenarios
- Large data sets
- Comprehensive memory testing
**Configuration**:
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Cache Size: 1000
- Timeout: 60 seconds
### 4. Stress Tests
- **Purpose**: Maximum load testing and edge case validation
- **Duration**: 10-30 minutes
- **Features**:
- Extreme iteration counts (100-500)
- Maximum concurrency (100+)
- Large memory allocations
- Edge case combinations
**Configuration**:
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Cache Size: 2000
- Timeout: 120 seconds
## Environment Variables
### Test Execution Control
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1 # Enable extended tests
export RUN_LONG_TESTS=1 # Enable long-running tests
export RUN_STRESS_TESTS=1 # Enable stress tests
# Disable specific features
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
```
### Parameter Customization
```bash
# Customize concurrency limits
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
# Customize iteration limits
export TEST_MAX_ITERATIONS=50 # Override max test iterations
# Customize memory thresholds
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
```
## Test-Specific Behavior
### Memory Leak Tests
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
- **Long Mode**: 50-100 iterations, large data sets, performance focus
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
### Concurrency Tests
- **Quick Mode**: 2-5 concurrent operations, basic race detection
- **Extended Mode**: 10-20 concurrent operations, moderate stress
- **Long Mode**: 20-50 concurrent operations, high contention
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
### Cache Tests
- **Quick Mode**: Small caches (50 items), basic operations
- **Extended Mode**: Medium caches (200 items), varied operations
- **Long Mode**: Large caches (1000 items), performance testing
- **Stress Mode**: Very large caches (2000+ items), stress testing
## Integration with CI/CD
### GitHub Actions Example
```yaml
# Quick tests for every push/PR
- name: Quick Tests
run: go test ./... -short
# Extended tests for main branch
- name: Extended Tests
if: github.ref == 'refs/heads/main'
run: RUN_EXTENDED_TESTS=1 go test ./...
# Nightly comprehensive testing
- name: Nightly Stress Tests
if: github.event_name == 'schedule'
run: RUN_STRESS_TESTS=1 go test ./...
```
### Local Development Workflow
```bash
# During active development
go test ./... -short
# Before committing
RUN_EXTENDED_TESTS=1 go test ./...
# Before major releases
RUN_LONG_TESTS=1 go test ./...
# Performance validation
RUN_STRESS_TESTS=1 go test ./...
```
## Performance Optimization Features
### Dynamic Test Scaling
The test system automatically adjusts parameters based on:
- Test mode (quick/extended/long/stress)
- Available resources
- Environment variables
- Previous test performance
### Memory Management
- **Garbage Collection**: Forced GC between test iterations
- **Memory Monitoring**: Real-time memory growth tracking
- **Leak Detection**: Goroutine and memory leak prevention
- **Resource Cleanup**: Automatic cleanup of test resources
### Timeout Management
- **Adaptive Timeouts**: Timeouts scale with test complexity
- **Graceful Degradation**: Tests adapt to slower environments
- **Early Termination**: Failed tests terminate quickly
## Troubleshooting
### Tests Taking Too Long
```bash
# Check if running in extended mode accidentally
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
# Force quick mode
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
go test ./... -short
```
### Memory Issues
```bash
# Reduce memory limits for constrained environments
export TEST_MEMORY_THRESHOLD_MB=5.0
export TEST_MAX_CONCURRENCY=2
go test ./...
```
### Concurrency Issues
```bash
# Reduce concurrency for slower systems
export TEST_MAX_CONCURRENCY=5
export TEST_MAX_ITERATIONS=10
go test ./...
```
### Skip Specific Test Types
```bash
# Skip memory leak detection if problematic
export DISABLE_LEAK_DETECTION=1
go test ./...
```
## Benchmarking
### Running Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
### Benchmark Categories
- **Basic Operations**: Set/Get performance
- **Concurrency**: Multi-threaded performance
- **Memory**: Allocation and cleanup performance
- **Cache**: Eviction and cleanup performance
## Best Practices
### For Developers
1. Always run quick tests during development (`go test ./... -short`)
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
3. Use appropriate test categories for your use case
4. Monitor test execution time and adjust if needed
### For CI/CD
1. Use quick tests for fast feedback on PRs
2. Use extended tests for main branch validation
3. Use long tests for release validation
4. Use stress tests for nightly/weekly validation
### For Performance Testing
1. Use consistent environment variables
2. Run tests multiple times for statistical significance
3. Monitor both execution time and resource usage
4. Use profiling tools for detailed analysis
## Examples
### Daily Development
```bash
# Fast tests while coding
go test ./... -short
# Before git commit
RUN_EXTENDED_TESTS=1 go test ./...
```
### Release Testing
```bash
# Comprehensive validation
RUN_LONG_TESTS=1 go test ./...
# Stress testing
RUN_STRESS_TESTS=1 go test ./...
```
### Custom Configuration
```bash
# Custom limits for specific environment
export TEST_MAX_CONCURRENCY=8
export TEST_MAX_ITERATIONS=25
export TEST_MEMORY_THRESHOLD_MB=15.0
RUN_EXTENDED_TESTS=1 go test ./...
```
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
-163
View File
@@ -1,163 +0,0 @@
# Google OAuth Integration Fix
## Problem Overview
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
```
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
```
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
## Technical Details of the Issue
### Standard OIDC Provider Behavior
Most OpenID Connect (OIDC) providers follow the standard specification, where:
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
- This allows authenticated sessions to persist beyond the initial access token expiration
### Google's Non-Standard Approach
Google's OAuth implementation deviates from the standard by:
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
## Solution Implementation
The fix involved modifying the authentication flow to specifically handle Google providers:
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
```go
// Check if we're dealing with a Google OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
strings.Contains(t.issuerURL, "accounts.google.com")
```
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
```go
// Handle offline access differently for Google vs other providers
if isGoogleProvider {
// For Google, use access_type=offline parameter instead of offline_access scope
params.Set("access_type", "offline")
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
// Add prompt=consent for Google to ensure refresh token is issued
params.Set("prompt", "consent")
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else {
// For non-Google providers, use the offline_access scope
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
}
}
```
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
## Why This Approach Works
This solution aligns with Google's OAuth 2.0 documentation which specifies:
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
## Testing and Verification
Comprehensive tests were implemented to verify the solution:
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
2. **Auth URL Parameter Tests**: Verifies that:
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
Sample test case (simplified):
```go
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that access_type=offline was added (not offline_access scope for Google)
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
}
// Verify offline_access scope is NOT included for Google providers
if strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
```
## Usage Guidance for Developers
When configuring the Traefik OIDC middleware for Google:
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
- Configure the authorized redirect URI to match your `callbackURL` setting
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
3. **Configuration Example**:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-google-client-id.apps.googleusercontent.com
clientSecret: your-google-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
scopes:
- openid
- email
- profile
# Note: DO NOT manually add offline_access scope for Google
# The middleware handles this automatically and correctly
```
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
- Review your application logs with `logLevel: debug` to check for refresh token errors
- Verify you're using a version of the middleware that includes this fix
## Conclusion
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
+1
View File
@@ -348,6 +348,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
+620
View File
@@ -0,0 +1,620 @@
package traefikoidc
import (
"context"
"encoding/base64"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
"github.com/stretchr/testify/suite"
"golang.org/x/time/rate"
)
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
type ClockSkewEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
// Create JWK for the test key
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error") // Reduce noise
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
token, err := s.fixture.TokenWithSkew(0)
s.Require().NoError(err)
// Token at exact expiry - behavior is implementation-defined
err = s.tOidc.VerifyToken(token)
s.T().Logf("Exact expiry result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
token, err := s.fixture.TokenWithSkew(1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token should be valid 1 second before expiry")
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
}
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
// Most implementations allow 5-minute clock skew
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// May pass or fail depending on implementation
s.T().Logf("4-minute expired token result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.Error(err, "Token should be invalid 10 minutes after expiry")
}
func TestClockSkewEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ClockSkewEdgeCasesSuite))
}
// UnicodeClaimsSuite tests Unicode handling in JWT claims
type UnicodeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *UnicodeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *UnicodeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
token, err := s.fixture.TokenWithEmail("用户@example.com")
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode email should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestUnicodeName() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "田中太郎",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode name should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test User 😀",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Emoji in claims should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestRTLText() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "مستخدم اختبار",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "RTL text should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestMixedScripts() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test 测试 テスト",
"roles": []string{"admin", "管理者", "管理员"},
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Mixed scripts should be handled correctly")
}
func TestUnicodeClaimsSuite(t *testing.T) {
suite.Run(t, new(UnicodeClaimsSuite))
}
// LargeClaimsSuite tests large claim values
type LargeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *LargeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *LargeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *LargeClaimsSuite) TestManyRoles() {
roles := make([]string, 100)
for i := 0; i < 100; i++ {
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithRoles(roles)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 100 roles should be handled")
}
func (s *LargeClaimsSuite) TestManyGroups() {
groups := make([]string, 50)
for i := 0; i < 50; i++ {
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithGroups(groups)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 50 groups should be handled")
}
func (s *LargeClaimsSuite) TestLongEmail() {
// RFC 5321 allows up to 254 characters
localPart := strings.Repeat("a", 64)
domain := strings.Repeat("b", 63) + ".com"
email := localPart + "@" + domain
token, err := s.fixture.TokenWithEmail(email)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long email should be handled")
}
func (s *LargeClaimsSuite) TestLongSubject() {
longSub := strings.Repeat("subject", 100)
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"sub": longSub,
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long subject should be handled")
}
func TestLargeClaimsSuite(t *testing.T) {
suite.Run(t, new(LargeClaimsSuite))
}
// URLPathEdgeCasesSuite tests URL handling edge cases
type URLPathEdgeCasesSuite struct {
suite.Suite
}
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
longPath := "/" + strings.Repeat("segment/", 100)
req := httptest.NewRequest("GET", longPath, nil)
s.NotNil(req)
s.Contains(req.URL.Path, "segment")
}
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
paths := []string{
"/path%20with%20spaces",
"/path/with/日本語",
"/path?query=value&another=test",
"/path#fragment",
"/path/../traversal",
"/path/./current",
}
for _, path := range paths {
s.Run(path, func() {
req := httptest.NewRequest("GET", path, nil)
s.NotNil(req)
})
}
}
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
req := httptest.NewRequest("GET", "/", nil)
s.Equal("/", req.URL.Path)
}
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
req := httptest.NewRequest("GET", "//double//slashes//", nil)
s.NotNil(req)
}
func TestURLPathEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(URLPathEdgeCasesSuite))
}
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
type ConcurrencyEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
const goroutines = 50
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}()
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
const goroutines = 20
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"custom": idx,
})
if err != nil {
errors <- err
return
}
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent different token error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent different token validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
validToken, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
expiredToken, err := s.fixture.ExpiredToken()
s.Require().NoError(err)
const goroutines = 40
var wg sync.WaitGroup
validCount := int32(0)
expiredCount := int32(0)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
var token string
if idx%2 == 0 {
token = validToken
} else {
token = expiredToken
}
err := s.tOidc.VerifyToken(token)
if idx%2 == 0 {
if err == nil {
atomic.AddInt32(&validCount, 1)
}
} else {
if err != nil {
atomic.AddInt32(&expiredCount, 1)
}
}
}(i)
}
wg.Wait()
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
}
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
}
+258
View File
@@ -0,0 +1,258 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
type EnhancedMocksSuite struct {
suite.Suite
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Make some calls
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.NoError(err)
s.NotNil(result)
// Another call with different URL
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
// Verify calls were tracked
s.Equal(2, mock.GetJWKSCallCount())
mock.AssertGetJWKSCalled(s.T())
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
mock.AssertGetJWKSCallCount(s.T(), 2)
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
expectedErr := errors.New("network error")
mock := &EnhancedMockJWKCache{
Err: expectedErr,
}
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Nil(result)
s.Equal(expectedErr, err)
mock.AssertGetJWKSCalled(s.T())
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Equal(1, mock.GetJWKSCallCount())
mock.Reset()
s.Equal(0, mock.GetJWKSCallCount())
s.Nil(mock.JWKS)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
mock := &EnhancedMockTokenVerifier{
Err: nil, // Valid tokens
}
// Verify a token
err := mock.VerifyToken("test-token-1")
s.NoError(err)
// Verify another token
err = mock.VerifyToken("test-token-2")
s.NoError(err)
// Check tracking
s.Equal(2, mock.GetVerifyTokenCallCount())
mock.AssertVerifyTokenCalled(s.T())
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
// Check last call
lastCall := mock.LastCall()
s.NotNil(lastCall)
s.Equal("test-token-2", lastCall.Token)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
callCount := 0
mock := &EnhancedMockTokenVerifier{
VerifyFunc: func(token string) error {
callCount++
if token == "invalid" {
return errors.New("invalid token")
}
return nil
},
}
// Valid token
err := mock.VerifyToken("valid-token")
s.NoError(err)
// Invalid token
err = mock.VerifyToken("invalid")
s.Error(err)
s.Equal(2, callCount)
s.Equal(2, mock.GetVerifyTokenCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
mock := &EnhancedMockTokenExchanger{
ExchangeResponse: &TokenResponse{
AccessToken: "access-token",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
},
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
},
}
// Exchange code
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
s.NoError(err)
s.Equal("access-token", resp.AccessToken)
// Refresh token
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
s.NoError(err)
s.Equal("new-access-token", resp.AccessToken)
// Revoke token
err = mock.RevokeTokenWithProvider("access-token", "access_token")
s.NoError(err)
// Check tracking
mock.AssertExchangeCalled(s.T())
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
mock.AssertRefreshCalled(s.T())
mock.AssertRevokeCalled(s.T())
s.Equal(1, mock.GetExchangeCallCount())
s.Equal(1, mock.GetRefreshCallCount())
s.Equal(1, mock.GetRevokeCallCount())
// Check last exchange call details
lastExchange := mock.LastExchangeCall()
s.NotNil(lastExchange)
s.Equal("authorization_code", lastExchange.GrantType)
s.Equal("auth-code", lastExchange.CodeOrToken)
s.Equal("https://redirect.com", lastExchange.RedirectURL)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
mock := &EnhancedMockTokenExchanger{
ExchangeErr: errors.New("invalid_grant"),
RefreshErr: errors.New("refresh_expired"),
RevokeErr: errors.New("revoke_failed"),
}
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
s.Error(err)
s.Contains(err.Error(), "invalid_grant")
_, err = mock.GetNewTokenWithRefreshToken("token")
s.Error(err)
s.Contains(err.Error(), "refresh_expired")
err = mock.RevokeTokenWithProvider("token", "access_token")
s.Error(err)
s.Contains(err.Error(), "revoke_failed")
}
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
mock := NewEnhancedMockCache()
// Set some values
mock.Set("key1", "value1", 5*time.Minute)
mock.Set("key2", "value2", 10*time.Minute)
// Get values
val, found := mock.Get("key1")
s.True(found)
s.Equal("value1", val)
_, found = mock.Get("nonexistent")
s.False(found)
// Delete
mock.Delete("key1")
// Verify tracking
mock.AssertSetCalled(s.T(), "key1")
mock.AssertSetCalled(s.T(), "key2")
mock.AssertGetCalled(s.T(), "key1")
mock.AssertGetCalled(s.T(), "nonexistent")
mock.AssertDeleteCalled(s.T(), "key1")
s.Equal(2, mock.SetCallCount())
s.Equal(2, mock.GetCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
mock := NewEnhancedMockCache()
// The enhanced mock actually stores data
mock.Set("key", "value", time.Hour)
s.Equal(1, mock.Size())
val, found := mock.Get("key")
s.True(found)
s.Equal("value", val)
mock.Delete("key")
s.Equal(0, mock.Size())
_, found = mock.Get("key")
s.False(found)
}
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
mock := NewEnhancedMockCache()
mock.Set("key1", "value1", time.Hour)
mock.Set("key2", "value2", time.Hour)
s.Equal(2, mock.Size())
mock.Clear()
s.Equal(0, mock.Size())
}
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Concurrent calls should be safe
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
s.Equal(10, mock.GetJWKSCallCount())
}
func TestEnhancedMocksSuite(t *testing.T) {
suite.Run(t, new(EnhancedMocksSuite))
}
+595
View File
@@ -0,0 +1,595 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/stretchr/testify/assert"
)
// EnhancedMockJWKCache is an improved state-based mock with call tracking
type EnhancedMockJWKCache struct {
mu sync.RWMutex
// State (what to return)
JWKS *JWKSet
Err error
// Call tracking
GetJWKSCalls []JWKSCall
CleanupCalls int32
CloseCalls int32
getJWKSCallsMu sync.Mutex
}
// JWKSCall records parameters from a GetJWKS call
type JWKSCall struct {
URL string
Timestamp time.Time
}
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
URL: jwksURL,
Timestamp: time.Now(),
})
m.getJWKSCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
defer m.mu.Unlock()
m.JWKS = nil
m.Err = nil
}
func (m *EnhancedMockJWKCache) Close() {
atomic.AddInt32(&m.CloseCalls, 1)
}
// Assertion helpers
// AssertGetJWKSCalled verifies GetJWKS was called
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
}
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
for _, call := range m.GetJWKSCalls {
if call.URL == expectedURL {
return true
}
}
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
}
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
}
// GetJWKSCallCount returns the number of GetJWKS calls
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return len(m.GetJWKSCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockJWKCache) Reset() {
m.mu.Lock()
m.JWKS = nil
m.Err = nil
m.mu.Unlock()
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = nil
m.getJWKSCallsMu.Unlock()
atomic.StoreInt32(&m.CleanupCalls, 0)
atomic.StoreInt32(&m.CloseCalls, 0)
}
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
type EnhancedMockTokenVerifier struct {
mu sync.RWMutex
// State (what to return) - can be a fixed error or a function
Err error
VerifyFunc func(token string) error
// Call tracking
VerifyCalls []TokenVerifyCall
verifyCallsMu sync.Mutex
}
// TokenVerifyCall records parameters from a VerifyToken call
type TokenVerifyCall struct {
Token string
Timestamp time.Time
Result error
}
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
var result error
m.mu.RLock()
if m.VerifyFunc != nil {
result = m.VerifyFunc(token)
} else {
result = m.Err
}
m.mu.RUnlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
Token: token,
Timestamp: time.Now(),
Result: result,
})
m.verifyCallsMu.Unlock()
return result
}
// Assertion helpers
// AssertVerifyTokenCalled verifies VerifyToken was called
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
}
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
for _, call := range m.VerifyCalls {
if call.Token == expectedToken {
return true
}
}
return assert.Fail(t, "VerifyToken was not called with expected token")
}
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
}
// GetVerifyTokenCallCount returns the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return len(m.VerifyCalls)
}
// LastCall returns the most recent VerifyToken call
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
if len(m.VerifyCalls) == 0 {
return nil
}
return &m.VerifyCalls[len(m.VerifyCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenVerifier) Reset() {
m.mu.Lock()
m.Err = nil
m.VerifyFunc = nil
m.mu.Unlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = nil
m.verifyCallsMu.Unlock()
}
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
type EnhancedMockTokenExchanger struct {
mu sync.RWMutex
// State (what to return)
ExchangeResponse *TokenResponse
ExchangeErr error
RefreshResponse *TokenResponse
RefreshErr error
RevokeErr error
// Optional functions for dynamic behavior
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
// Call tracking
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
}
// ExchangeCall records parameters from an ExchangeCodeForToken call
type ExchangeCall struct {
GrantType string
CodeOrToken string
RedirectURL string
CodeVerifier string
Timestamp time.Time
}
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
type RefreshCall struct {
RefreshToken string
Timestamp time.Time
}
// RevokeCall records parameters from a RevokeTokenWithProvider call
type RevokeCall struct {
Token string
TokenType string
Timestamp time.Time
}
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
m.exchangeCallsMu.Lock()
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
GrantType: grantType,
CodeOrToken: codeOrToken,
RedirectURL: redirectURL,
CodeVerifier: codeVerifier,
Timestamp: time.Now(),
})
m.exchangeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return m.ExchangeResponse, m.ExchangeErr
}
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
m.refreshCallsMu.Lock()
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
RefreshToken: refreshToken,
Timestamp: time.Now(),
})
m.refreshCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return m.RefreshResponse, m.RefreshErr
}
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
m.revokeCallsMu.Lock()
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
Token: token,
TokenType: tokenType,
Timestamp: time.Now(),
})
m.revokeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return m.RevokeErr
}
// Assertion helpers
// AssertExchangeCalled verifies ExchangeCodeForToken was called
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
}
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
for _, call := range m.ExchangeCalls {
if call.GrantType == grantType {
return true
}
}
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
}
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
}
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
}
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return len(m.ExchangeCalls)
}
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return len(m.RefreshCalls)
}
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return len(m.RevokeCalls)
}
// LastExchangeCall returns the most recent ExchangeCodeForToken call
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
if len(m.ExchangeCalls) == 0 {
return nil
}
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenExchanger) Reset() {
m.mu.Lock()
m.ExchangeResponse = nil
m.ExchangeErr = nil
m.RefreshResponse = nil
m.RefreshErr = nil
m.RevokeErr = nil
m.ExchangeCodeFunc = nil
m.RefreshTokenFunc = nil
m.RevokeTokenFunc = nil
m.mu.Unlock()
m.exchangeCallsMu.Lock()
m.ExchangeCalls = nil
m.exchangeCallsMu.Unlock()
m.refreshCallsMu.Lock()
m.RefreshCalls = nil
m.refreshCallsMu.Unlock()
m.revokeCallsMu.Lock()
m.RevokeCalls = nil
m.revokeCallsMu.Unlock()
}
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
type EnhancedMockCacheInterface struct {
mu sync.RWMutex
// Internal storage
data map[string]cacheEntry
maxSize int
// Call tracking
GetCalls []CacheGetCall
SetCalls []CacheSetCall
DeleteCalls []string
getCalls sync.Mutex
setCalls sync.Mutex
deleteCalls sync.Mutex
}
type cacheEntry struct {
value any
ttl time.Duration
}
// CacheGetCall records parameters from a Get call
type CacheGetCall struct {
Key string
Found bool
Timestamp time.Time
}
// CacheSetCall records parameters from a Set call
type CacheSetCall struct {
Key string
Value any
TTL time.Duration
Timestamp time.Time
}
// NewEnhancedMockCache creates a new enhanced cache mock
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
return &EnhancedMockCacheInterface{
data: make(map[string]cacheEntry),
maxSize: 1000,
}
}
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
m.setCalls.Lock()
m.SetCalls = append(m.SetCalls, CacheSetCall{
Key: key,
Value: value,
TTL: ttl,
Timestamp: time.Now(),
})
m.setCalls.Unlock()
m.mu.Lock()
m.data[key] = cacheEntry{value: value, ttl: ttl}
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
m.mu.RLock()
entry, found := m.data[key]
m.mu.RUnlock()
m.getCalls.Lock()
m.GetCalls = append(m.GetCalls, CacheGetCall{
Key: key,
Found: found,
Timestamp: time.Now(),
})
m.getCalls.Unlock()
if found {
return entry.value, true
}
return nil, false
}
func (m *EnhancedMockCacheInterface) Delete(key string) {
m.deleteCalls.Lock()
m.DeleteCalls = append(m.DeleteCalls, key)
m.deleteCalls.Unlock()
m.mu.Lock()
delete(m.data, key)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
m.mu.Lock()
m.maxSize = size
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Size() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.data)
}
func (m *EnhancedMockCacheInterface) Clear() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Cleanup() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) Close() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
m.mu.RLock()
defer m.mu.RUnlock()
return map[string]any{
"size": len(m.data),
"max_size": m.maxSize,
}
}
// Assertion helpers
// AssertGetCalled verifies Get was called with specific key
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
m.getCalls.Lock()
defer m.getCalls.Unlock()
for _, call := range m.GetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Get was not called with key: "+key)
}
// AssertSetCalled verifies Set was called with specific key
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
m.setCalls.Lock()
defer m.setCalls.Unlock()
for _, call := range m.SetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Set was not called with key: "+key)
}
// AssertDeleteCalled verifies Delete was called with specific key
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
m.deleteCalls.Lock()
defer m.deleteCalls.Unlock()
for _, k := range m.DeleteCalls {
if k == key {
return true
}
}
return assert.Fail(t, "Delete was not called with key: "+key)
}
// GetCallCount returns the number of Get calls
func (m *EnhancedMockCacheInterface) GetCallCount() int {
m.getCalls.Lock()
defer m.getCalls.Unlock()
return len(m.GetCalls)
}
// SetCallCount returns the number of Set calls
func (m *EnhancedMockCacheInterface) SetCallCount() int {
m.setCalls.Lock()
defer m.setCalls.Unlock()
return len(m.SetCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockCacheInterface) Reset() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
m.getCalls.Lock()
m.GetCalls = nil
m.getCalls.Unlock()
m.setCalls.Lock()
m.SetCalls = nil
m.setCalls.Unlock()
m.deleteCalls.Lock()
m.DeleteCalls = nil
m.deleteCalls.Unlock()
}
+131
View File
@@ -2,10 +2,14 @@ package traefikoidc
import (
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"math"
"math/rand/v2"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
// fetching during startup. Uses more aggressive retry settings to handle the race
// condition where Traefik initializes the plugin before routes are fully established,
// or before TLS certificates are properly loaded.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func MetadataFetchRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 10, // More attempts for startup scenarios
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
MaxDelay: 10 * time.Second, // Cap at 10 seconds
BackoffFactor: 1.5, // Gentler backoff for startup
EnableJitter: true, // Prevent thundering herd
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
}
// RetryExecutor implements retry logic with exponential backoff and jitter.
// It automatically retries failed operations based on configurable error patterns
// and uses exponential backoff to avoid overwhelming failing services.
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// isRetryableError checks if an error should trigger a retry
// isRetryableError determines if an error should trigger a retry attempt.
// Checks error message against configured retryable error patterns.
// Also handles startup-specific errors like Traefik default certificate errors
// and EOF errors that occur during service initialization.
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for Traefik default certificate error (startup race condition)
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
if isTraefikDefaultCertError(err) {
return true
}
// Check for EOF errors (common during startup when services aren't ready)
if isEOFError(err) {
return true
}
// Check for certificate errors (transient during startup)
if isCertificateError(err) {
return true
}
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
@@ -538,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
delay = float64(re.config.MaxDelay)
}
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.EnableJitter {
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
delay += jitter
@@ -1087,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
}
return false
}
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
// certificate during cold-start, before the real certificates are loaded.
// This manifests as an x509.HostnameError where one of the certificate's DNS names
// ends with "traefik.default" (the default Traefik certificate pattern).
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func isTraefikDefaultCertError(err error) bool {
if err == nil {
return false
}
var hostnameErr x509.HostnameError
if errors.As(err, &hostnameErr) {
if hostnameErr.Certificate != nil {
for _, name := range hostnameErr.Certificate.DNSNames {
if strings.HasSuffix(name, "traefik.default") {
return true
}
}
}
}
return false
}
// isEOFError checks if an error is an EOF error, which can occur during
// connection establishment when the remote end closes unexpectedly.
// This is common during service startup when endpoints aren't fully ready.
func isEOFError(err error) bool {
if err == nil {
return false
}
// Check for direct EOF
if errors.Is(err, io.EOF) {
return true
}
// Check for unexpected EOF
if errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check error message for EOF patterns (wrapped errors)
errStr := err.Error()
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
}
// isCertificateError checks if an error is related to TLS certificate validation.
// These errors are often transient during startup when services are still initializing.
func isCertificateError(err error) bool {
if err == nil {
return false
}
// Check for x509 certificate errors
var certInvalidErr x509.CertificateInvalidError
var hostnameErr x509.HostnameError
var unknownAuthErr x509.UnknownAuthorityError
if errors.As(err, &certInvalidErr) ||
errors.As(err, &hostnameErr) ||
errors.As(err, &unknownAuthErr) {
return true
}
// Check error message for certificate patterns
errStr := strings.ToLower(err.Error())
certPatterns := []string{
"certificate",
"x509",
"tls",
"ssl",
}
for _, pattern := range certPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
-242
View File
@@ -1,242 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
// TestDefaultCircuitBreakerConfig tests the default configuration function
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
// Test default values
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
}
}
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
metrics := base.GetBaseMetrics()
if metrics == nil {
t.Fatal("Expected non-nil metrics")
}
// Check expected metric fields
expectedFields := []string{
"total_requests",
"total_failures",
"total_successes",
"uptime_seconds",
"name",
}
for _, field := range expectedFields {
if _, exists := metrics[field]; !exists {
t.Errorf("Expected metric field %s to exist", field)
}
}
}
// TestBaseRecoveryMechanism_RecordRequest tests request recording
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some requests
base.RecordRequest()
base.RecordRequest()
base.RecordRequest()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalRequests := metrics["total_requests"].(int64)
if totalRequests != 3 {
t.Errorf("Expected 3 total requests, got %d", totalRequests)
}
}
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some successes
base.RecordSuccess()
base.RecordSuccess()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalSuccesses := metrics["total_successes"].(int64)
if totalSuccesses != 2 {
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
}
}
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some failures
base.RecordFailure()
base.RecordFailure()
base.RecordFailure()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalFailures := metrics["total_failures"].(int64)
if totalFailures != 3 {
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
}
}
// TestBaseRecoveryMechanism_LogInfo tests info logging
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogInfo("test message")
base.LogInfo("test message with args: %s %d", "arg1", 42)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogInfo("test message") // Should not panic
}
// TestBaseRecoveryMechanism_LogError tests error logging
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogError("error message")
base.LogError("error message with args: %s %d", "error", 500)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogError("error message") // Should not panic
}
// TestBaseRecoveryMechanism_LogDebug tests debug logging
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogDebug("debug message")
base.LogDebug("debug message with args: %s %d", "debug", 123)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogDebug("debug message") // Should not panic
}
// TestCircuitBreaker_GetState tests getting circuit breaker state
func TestCircuitBreaker_GetState(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Initial state should be closed
state := cb.GetState()
if state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be closed, got %d", state)
}
}
// TestCircuitBreaker_Reset tests resetting circuit breaker
func TestCircuitBreaker_Reset(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Reset should not panic
cb.Reset()
// State should be closed after reset
state := cb.GetState()
if state != CircuitBreakerClosed {
t.Errorf("Expected state to be closed after reset, got %d", state)
}
}
// TestCircuitBreaker_IsAvailable tests availability check
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Initially should be available
available := cb.IsAvailable()
if !available {
t.Error("Expected circuit breaker to be available initially")
}
}
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
metrics := cb.GetMetrics()
if metrics == nil {
t.Fatal("Expected non-nil metrics")
}
// Should include base metrics
if _, exists := metrics["total_requests"]; !exists {
t.Error("Expected total_requests in metrics")
}
// Should include circuit breaker specific metrics
if _, exists := metrics["state"]; !exists {
t.Error("Expected state in metrics")
}
}
// Retry mechanism tests removed due to complex dependencies
// Benchmark tests
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
-560
View File
@@ -1,560 +0,0 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRetryExecutorReset tests the Reset method
func TestRetryExecutorReset(t *testing.T) {
logger := GetSingletonNoOpLogger()
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
require.NotNil(t, executor)
// Should not panic
assert.NotPanics(t, func() {
executor.Reset()
})
// Multiple resets should be safe
executor.Reset()
executor.Reset()
}
// TestRetryExecutorIsAvailable tests the IsAvailable method
func TestRetryExecutorIsAvailable(t *testing.T) {
logger := GetSingletonNoOpLogger()
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
// Retry executor should always be available
assert.True(t, executor.IsAvailable())
// Should remain available after operations
ctx := context.Background()
executor.ExecuteWithContext(ctx, func() error {
return nil
})
assert.True(t, executor.IsAvailable())
}
// TestSessionErrorUnwrap tests SessionError.Unwrap
func TestSessionErrorUnwrap(t *testing.T) {
t.Run("unwrap with cause", func(t *testing.T) {
rootErr := errors.New("root cause")
sessionErr := NewSessionError("save", "failed to save session", rootErr)
unwrapped := sessionErr.Unwrap()
assert.Equal(t, rootErr, unwrapped)
})
t.Run("unwrap without cause", func(t *testing.T) {
sessionErr := NewSessionError("load", "failed to load session", nil)
unwrapped := sessionErr.Unwrap()
assert.Nil(t, unwrapped)
})
t.Run("error chain", func(t *testing.T) {
rootErr := errors.New("database error")
sessionErr := NewSessionError("delete", "failed to delete session", rootErr)
// Verify error chain works
assert.True(t, errors.Is(sessionErr, rootErr))
})
}
// TestTokenErrorUnwrap tests TokenError.Unwrap
func TestTokenErrorUnwrap(t *testing.T) {
t.Run("unwrap with cause", func(t *testing.T) {
rootErr := errors.New("signature verification failed")
tokenErr := NewTokenError("id_token", "invalid", "token is invalid", rootErr)
unwrapped := tokenErr.Unwrap()
assert.Equal(t, rootErr, unwrapped)
})
t.Run("unwrap without cause", func(t *testing.T) {
tokenErr := NewTokenError("access_token", "expired", "token has expired", nil)
unwrapped := tokenErr.Unwrap()
assert.Nil(t, unwrapped)
})
t.Run("error chain", func(t *testing.T) {
rootErr := errors.New("crypto error")
tokenErr := NewTokenError("refresh_token", "malformed", "token is malformed", rootErr)
// Verify error chain works
assert.True(t, errors.Is(tokenErr, rootErr))
})
}
// TestGracefulDegradationRegisterFallback tests fallback registration
func TestGracefulDegradationRegisterFallback(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("register single fallback", func(t *testing.T) {
fallback := func() (interface{}, error) {
return "fallback result", nil
}
gd.RegisterFallback("service1", fallback)
// Verify fallback was registered (indirectly)
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
return nil, errors.New("service failed")
})
assert.NoError(t, err)
assert.Equal(t, "fallback result", result)
})
t.Run("register multiple fallbacks", func(t *testing.T) {
gd.RegisterFallback("service2", func() (interface{}, error) {
return "fallback2", nil
})
gd.RegisterFallback("service3", func() (interface{}, error) {
return "fallback3", nil
})
result2, _ := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
return nil, errors.New("fail")
})
result3, _ := gd.ExecuteWithFallback("service3", func() (interface{}, error) {
return nil, errors.New("fail")
})
assert.Equal(t, "fallback2", result2)
assert.Equal(t, "fallback3", result3)
})
t.Run("override existing fallback", func(t *testing.T) {
gd.RegisterFallback("service4", func() (interface{}, error) {
return "old fallback", nil
})
gd.RegisterFallback("service4", func() (interface{}, error) {
return "new fallback", nil
})
result, _ := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
return nil, errors.New("fail")
})
assert.Equal(t, "new fallback", result)
})
}
// TestGracefulDegradationRegisterHealthCheck tests health check registration
func TestGracefulDegradationRegisterHealthCheck(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
config.HealthCheckInterval = 50 * time.Millisecond
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("register health check", func(t *testing.T) {
healthy := true
healthCheck := func() bool {
return healthy
}
gd.RegisterHealthCheck("service1", healthCheck)
// Mark service as degraded
gd.markServiceDegraded("service1")
assert.True(t, gd.isServiceDegraded("service1"))
// Set healthy and wait for health check to run
healthy = true
time.Sleep(100 * time.Millisecond)
// Service should be recovered
// (may still be degraded due to timing, but health check was registered)
})
t.Run("multiple health checks", func(t *testing.T) {
gd.RegisterHealthCheck("service2", func() bool { return true })
gd.RegisterHealthCheck("service3", func() bool { return false })
// Health checks are registered and will be called periodically
})
}
// TestGracefulDegradationExecuteWithContext tests ExecuteWithContext
func TestGracefulDegradationExecuteWithContext(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("successful execution", func(t *testing.T) {
ctx := context.Background()
err := gd.ExecuteWithContext(ctx, func() error {
return nil
})
assert.NoError(t, err)
})
t.Run("failed execution", func(t *testing.T) {
ctx := context.Background()
testErr := errors.New("operation failed")
err := gd.ExecuteWithContext(ctx, func() error {
return testErr
})
assert.Error(t, err)
})
t.Run("uses fallback on failure", func(t *testing.T) {
gd.RegisterFallback("default", func() (interface{}, error) {
return nil, nil // Success fallback
})
ctx := context.Background()
err := gd.ExecuteWithContext(ctx, func() error {
return errors.New("primary failed")
})
// With fallback succeeding, overall operation succeeds
assert.NoError(t, err)
})
}
// TestGracefulDegradationExecuteWithFallback tests ExecuteWithFallback
func TestGracefulDegradationExecuteWithFallback(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("primary succeeds", func(t *testing.T) {
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
return "primary result", nil
})
assert.NoError(t, err)
assert.Equal(t, "primary result", result)
})
t.Run("fallback succeeds when primary fails", func(t *testing.T) {
gd.RegisterFallback("service2", func() (interface{}, error) {
return "fallback result", nil
})
result, err := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
return nil, errors.New("primary failed")
})
assert.NoError(t, err)
assert.Equal(t, "fallback result", result)
})
t.Run("error when no fallback available", func(t *testing.T) {
config.EnableFallbacks = false
gdNoFallback := NewGracefulDegradation(config, logger)
defer gdNoFallback.Close()
result, err := gdNoFallback.ExecuteWithFallback("service3", func() (interface{}, error) {
return nil, errors.New("primary failed")
})
assert.Error(t, err)
assert.Nil(t, result)
})
t.Run("fallback also fails", func(t *testing.T) {
gd.RegisterFallback("service4", func() (interface{}, error) {
return nil, errors.New("fallback also failed")
})
result, err := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
return nil, errors.New("primary failed")
})
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "fallback also failed")
})
}
// TestGracefulDegradationIsServiceDegraded tests service degradation status
func TestGracefulDegradationIsServiceDegraded(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
config.RecoveryTimeout = 100 * time.Millisecond
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("service not degraded initially", func(t *testing.T) {
assert.False(t, gd.isServiceDegraded("new-service"))
})
t.Run("service degraded after marking", func(t *testing.T) {
gd.markServiceDegraded("service1")
assert.True(t, gd.isServiceDegraded("service1"))
})
t.Run("service recovers after timeout", func(t *testing.T) {
gd.markServiceDegraded("service2")
assert.True(t, gd.isServiceDegraded("service2"))
// Wait for recovery timeout
time.Sleep(150 * time.Millisecond)
// Should be recovered
assert.False(t, gd.isServiceDegraded("service2"))
})
}
// TestGracefulDegradationMarkServiceDegraded tests marking services as degraded
func TestGracefulDegradationMarkServiceDegraded(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("mark single service", func(t *testing.T) {
gd.markServiceDegraded("service1")
degraded := gd.GetDegradedServices()
assert.Contains(t, degraded, "service1")
})
t.Run("mark multiple services", func(t *testing.T) {
gd.markServiceDegraded("service2")
gd.markServiceDegraded("service3")
degraded := gd.GetDegradedServices()
assert.Contains(t, degraded, "service2")
assert.Contains(t, degraded, "service3")
})
t.Run("marking same service multiple times updates timestamp", func(t *testing.T) {
gd.markServiceDegraded("service4")
time.Sleep(50 * time.Millisecond)
gd.markServiceDegraded("service4")
// Service should still be marked as degraded
assert.True(t, gd.isServiceDegraded("service4"))
})
}
// TestGracefulDegradationExecuteFallback tests fallback execution
func TestGracefulDegradationExecuteFallback(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("execute registered fallback", func(t *testing.T) {
gd.RegisterFallback("service1", func() (interface{}, error) {
return "fallback value", nil
})
result, err := gd.executeFallback("service1")
assert.NoError(t, err)
assert.Equal(t, "fallback value", result)
})
t.Run("error when fallback not registered", func(t *testing.T) {
result, err := gd.executeFallback("non-existent-service")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "no fallback available")
})
t.Run("propagate fallback errors", func(t *testing.T) {
gd.RegisterFallback("service2", func() (interface{}, error) {
return nil, errors.New("fallback error")
})
result, err := gd.executeFallback("service2")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "fallback error")
})
}
// TestGracefulDegradationReset tests Reset method
func TestGracefulDegradationReset(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("reset clears degraded services", func(t *testing.T) {
// Mark several services as degraded
gd.markServiceDegraded("service1")
gd.markServiceDegraded("service2")
gd.markServiceDegraded("service3")
assert.Len(t, gd.GetDegradedServices(), 3)
// Reset
gd.Reset()
// All should be cleared
assert.Len(t, gd.GetDegradedServices(), 0)
})
t.Run("can mark services degraded after reset", func(t *testing.T) {
gd.Reset()
gd.markServiceDegraded("service4")
assert.Len(t, gd.GetDegradedServices(), 1)
assert.Contains(t, gd.GetDegradedServices(), "service4")
})
t.Run("multiple resets are safe", func(t *testing.T) {
assert.NotPanics(t, func() {
gd.Reset()
gd.Reset()
gd.Reset()
})
})
}
// TestGracefulDegradationIsAvailable tests IsAvailable method
func TestGracefulDegradationIsAvailable(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Should always return true
assert.True(t, gd.IsAvailable())
// Even with degraded services
gd.markServiceDegraded("service1")
assert.True(t, gd.IsAvailable())
// Even after reset
gd.Reset()
assert.True(t, gd.IsAvailable())
}
// TestGracefulDegradationGetMetrics tests GetMetrics method
func TestGracefulDegradationGetMetrics(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
t.Run("basic metrics", func(t *testing.T) {
metrics := gd.GetMetrics()
require.NotNil(t, metrics)
assert.Contains(t, metrics, "degraded_services_count")
assert.Contains(t, metrics, "degraded_services")
assert.Contains(t, metrics, "registered_fallbacks_count")
assert.Contains(t, metrics, "registered_health_checks_count")
assert.Contains(t, metrics, "health_check_interval_seconds")
assert.Contains(t, metrics, "recovery_timeout_seconds")
assert.Contains(t, metrics, "fallbacks_enabled")
})
t.Run("metrics reflect degraded services", func(t *testing.T) {
gd.Reset()
gd.markServiceDegraded("service1")
gd.markServiceDegraded("service2")
metrics := gd.GetMetrics()
assert.Equal(t, 2, metrics["degraded_services_count"])
degradedList := metrics["degraded_services"].([]string)
assert.Len(t, degradedList, 2)
})
t.Run("metrics reflect registered fallbacks", func(t *testing.T) {
gd.RegisterFallback("service1", func() (interface{}, error) { return nil, nil })
gd.RegisterFallback("service2", func() (interface{}, error) { return nil, nil })
metrics := gd.GetMetrics()
assert.GreaterOrEqual(t, metrics["registered_fallbacks_count"], 2)
})
t.Run("metrics include base metrics", func(t *testing.T) {
metrics := gd.GetMetrics()
// Should include base recovery mechanism metrics
assert.Contains(t, metrics, "name")
assert.Contains(t, metrics, "uptime_seconds")
assert.Contains(t, metrics, "total_requests")
})
}
// TestGracefulDegradationFullScenario tests a complete degradation scenario
func TestGracefulDegradationFullScenario(t *testing.T) {
if testing.Short() {
t.Skip("Skipping full scenario test in short mode")
}
logger := GetSingletonNoOpLogger()
config := DefaultGracefulDegradationConfig()
config.RecoveryTimeout = 200 * time.Millisecond
config.HealthCheckInterval = 50 * time.Millisecond
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Register fallback
gd.RegisterFallback("critical-service", func() (interface{}, error) {
return "fallback data", nil
})
// Register health check
serviceHealthy := false
gd.RegisterHealthCheck("critical-service", func() bool {
return serviceHealthy
})
// First call - primary succeeds
result1, err1 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
return "primary data", nil
})
assert.NoError(t, err1)
assert.Equal(t, "primary data", result1)
// Second call - primary fails, fallback succeeds
result2, err2 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
return nil, errors.New("service down")
})
assert.NoError(t, err2)
assert.Equal(t, "fallback data", result2)
// Service is now degraded
assert.True(t, gd.isServiceDegraded("critical-service"))
// Third call - should use fallback immediately
result3, err3 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
return "should not be called", nil
})
assert.NoError(t, err3)
assert.Equal(t, "fallback data", result3)
// Mark service as healthy and wait for health check
serviceHealthy = true
time.Sleep(250 * time.Millisecond)
// Service should be recovered
// (timing-dependent, so we don't assert)
// Get metrics
metrics := gd.GetMetrics()
assert.NotNil(t, metrics)
}
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import "testing"
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
-663
View File
@@ -1,663 +0,0 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestCircuitBreakerAllowRequestEdgeCases tests edge cases in circuit breaker request allowing
func TestCircuitBreakerAllowRequestEdgeCases(t *testing.T) {
logger := GetSingletonNoOpLogger()
t.Run("invalid state returns false", func(t *testing.T) {
config := DefaultCircuitBreakerConfig()
cb := NewCircuitBreaker(config, logger)
// Force invalid state
cb.mutex.Lock()
cb.state = CircuitBreakerState(999) // Invalid state
cb.mutex.Unlock()
// Should return false for invalid state
allowed := cb.allowRequest()
assert.False(t, allowed, "invalid state should not allow requests")
})
t.Run("open to half-open transition on timeout", func(t *testing.T) {
baseTimeout := GetTestDuration(50 * time.Millisecond)
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: baseTimeout,
ResetTimeout: 30 * time.Second,
}
cb := NewCircuitBreaker(config, logger)
// Trip the circuit
cb.Execute(func() error { return errors.New("fail") })
// Verify circuit is open
assert.Equal(t, CircuitBreakerOpen, cb.GetState())
assert.False(t, cb.allowRequest())
// Wait for timeout (longer than timeout to ensure transition)
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
// Should transition to half-open
allowed := cb.allowRequest()
assert.True(t, allowed, "should allow request after timeout")
assert.Equal(t, CircuitBreakerHalfOpen, cb.GetState())
})
t.Run("half-open allows requests", func(t *testing.T) {
config := DefaultCircuitBreakerConfig()
cb := NewCircuitBreaker(config, logger)
// Manually set to half-open
cb.mutex.Lock()
cb.state = CircuitBreakerHalfOpen
cb.mutex.Unlock()
allowed := cb.allowRequest()
assert.True(t, allowed, "half-open should allow requests")
})
t.Run("open blocks requests before timeout", func(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 1 * time.Hour, // Long timeout
ResetTimeout: 30 * time.Second,
}
cb := NewCircuitBreaker(config, logger)
// Trip the circuit
cb.Execute(func() error { return errors.New("fail") })
// Should be blocked
allowed := cb.allowRequest()
assert.False(t, allowed, "open circuit should block requests")
})
}
// TestRetryExecutorIsRetryableErrorEdgeCases tests edge cases for retry decision
func TestRetryExecutorIsRetryableErrorEdgeCases(t *testing.T) {
logger := GetSingletonNoOpLogger()
config := DefaultRetryConfig()
re := NewRetryExecutor(config, logger)
t.Run("nil error is not retryable", func(t *testing.T) {
retryable := re.isRetryableError(nil)
assert.False(t, retryable)
})
t.Run("HTTPError with 429 is retryable", func(t *testing.T) {
httpErr := &HTTPError{
StatusCode: 429,
Message: "Too Many Requests",
}
retryable := re.isRetryableError(httpErr)
assert.True(t, retryable, "429 Too Many Requests should be retryable")
})
t.Run("HTTPError with 500 is retryable", func(t *testing.T) {
httpErr := &HTTPError{
StatusCode: 500,
Message: "Internal Server Error",
}
retryable := re.isRetryableError(httpErr)
assert.True(t, retryable, "500 errors should be retryable")
})
t.Run("HTTPError with 503 is retryable", func(t *testing.T) {
httpErr := &HTTPError{
StatusCode: 503,
Message: "Service Unavailable",
}
retryable := re.isRetryableError(httpErr)
assert.True(t, retryable, "503 errors should be retryable")
})
t.Run("HTTPError with 400 is not retryable", func(t *testing.T) {
httpErr := &HTTPError{
StatusCode: 400,
Message: "Bad Request",
}
retryable := re.isRetryableError(httpErr)
assert.False(t, retryable, "400 errors should not be retryable")
})
t.Run("net.Error with timeout is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: true,
temporary: false,
msg: "timeout error",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "timeout errors should be retryable")
})
t.Run("net.Error with connection refused is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "connection refused",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "connection refused should be retryable")
})
t.Run("net.Error with connection reset is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "connection reset by peer",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "connection reset should be retryable")
})
t.Run("net.Error with network unreachable is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "network is unreachable",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "network unreachable should be retryable")
})
t.Run("net.Error with no route to host is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "no route to host",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "no route to host should be retryable")
})
t.Run("net.Error with temporary failure is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "temporary failure in name resolution",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "temporary failure should be retryable")
})
t.Run("net.Error with try again is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "try again later",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "try again should be retryable")
})
t.Run("net.Error with resource temporarily unavailable is retryable", func(t *testing.T) {
netErr := &mockNetError{
timeout: false,
temporary: false,
msg: "resource temporarily unavailable",
}
retryable := re.isRetryableError(netErr)
assert.True(t, retryable, "resource temporarily unavailable should be retryable")
})
t.Run("configured retryable error patterns", func(t *testing.T) {
err := errors.New("connection refused by server")
retryable := re.isRetryableError(err)
assert.True(t, retryable, "configured pattern should be retryable")
})
t.Run("non-retryable error", func(t *testing.T) {
err := errors.New("invalid input data")
retryable := re.isRetryableError(err)
assert.False(t, retryable, "non-configured error should not be retryable")
})
}
// TestRetryExecutorCalculateDelayEdgeCases tests delay calculation edge cases
func TestRetryExecutorCalculateDelayEdgeCases(t *testing.T) {
logger := GetSingletonNoOpLogger()
t.Run("delay calculation without jitter", func(t *testing.T) {
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
BackoffFactor: 2.0,
EnableJitter: false, // Jitter disabled
}
re := NewRetryExecutor(config, logger)
// Attempt 1: 100ms * 2^0 = 100ms
delay1 := re.calculateDelay(1)
assert.Equal(t, 100*time.Millisecond, delay1)
// Attempt 2: 100ms * 2^1 = 200ms
delay2 := re.calculateDelay(2)
assert.Equal(t, 200*time.Millisecond, delay2)
// Attempt 3: 100ms * 2^2 = 400ms
delay3 := re.calculateDelay(3)
assert.Equal(t, 400*time.Millisecond, delay3)
})
t.Run("delay calculation with jitter", func(t *testing.T) {
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
BackoffFactor: 2.0,
EnableJitter: true, // Jitter enabled
}
re := NewRetryExecutor(config, logger)
// With jitter, delay should be within 10% of expected
delay := re.calculateDelay(2)
expectedBase := 200 * time.Millisecond
minDelay := time.Duration(float64(expectedBase) * 0.9)
maxDelay := time.Duration(float64(expectedBase) * 1.1)
assert.GreaterOrEqual(t, delay, minDelay, "delay should be >= 90% of base")
assert.LessOrEqual(t, delay, maxDelay, "delay should be <= 110% of base")
})
t.Run("delay capped at max delay", func(t *testing.T) {
config := RetryConfig{
MaxAttempts: 10,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 500 * time.Millisecond, // Low max delay
BackoffFactor: 2.0,
EnableJitter: false,
}
re := NewRetryExecutor(config, logger)
// Attempt 10: would be 100ms * 2^9 = 51200ms, but capped at 500ms
delay := re.calculateDelay(10)
assert.Equal(t, 500*time.Millisecond, delay, "delay should be capped at max")
})
t.Run("delay with large backoff factor", func(t *testing.T) {
config := RetryConfig{
MaxAttempts: 5,
InitialDelay: 50 * time.Millisecond,
MaxDelay: 10 * time.Second,
BackoffFactor: 3.0, // Larger backoff
EnableJitter: false,
}
re := NewRetryExecutor(config, logger)
// Attempt 3: 50ms * 3^2 = 450ms
delay := re.calculateDelay(3)
assert.Equal(t, 450*time.Millisecond, delay)
})
}
// TestErrorTypesErrorMethodsWithoutCause tests error type Error() methods without cause
func TestErrorTypesErrorMethodsWithoutCause(t *testing.T) {
t.Run("HTTPError.Error without cause", func(t *testing.T) {
httpErr := &HTTPError{
StatusCode: 404,
Message: "Not Found",
}
errStr := httpErr.Error()
assert.Equal(t, "HTTP 404: Not Found", errStr)
})
t.Run("HTTPError.Error with different status codes", func(t *testing.T) {
testCases := []struct {
code int
message string
expected string
}{
{200, "OK", "HTTP 200: OK"},
{301, "Moved", "HTTP 301: Moved"},
{401, "Unauthorized", "HTTP 401: Unauthorized"},
{500, "Server Error", "HTTP 500: Server Error"},
}
for _, tc := range testCases {
httpErr := &HTTPError{
StatusCode: tc.code,
Message: tc.message,
}
assert.Equal(t, tc.expected, httpErr.Error())
}
})
t.Run("OIDCError.Error without cause", func(t *testing.T) {
oidcErr := &OIDCError{
Code: "invalid_token",
Message: "Token validation failed",
Context: make(map[string]interface{}),
}
errStr := oidcErr.Error()
assert.Equal(t, "OIDC error [invalid_token]: Token validation failed", errStr)
})
t.Run("OIDCError.Error with cause", func(t *testing.T) {
rootErr := errors.New("signature mismatch")
oidcErr := &OIDCError{
Code: "invalid_signature",
Message: "JWT signature invalid",
Context: make(map[string]interface{}),
Cause: rootErr,
}
errStr := oidcErr.Error()
assert.Contains(t, errStr, "OIDC error [invalid_signature]: JWT signature invalid")
assert.Contains(t, errStr, "caused by: signature mismatch")
})
t.Run("SessionError.Error without cause", func(t *testing.T) {
sessErr := &SessionError{
Operation: "load",
Message: "Session not found",
SessionID: "sess123",
}
errStr := sessErr.Error()
assert.Equal(t, "Session error in load: Session not found", errStr)
})
t.Run("SessionError.Error with cause", func(t *testing.T) {
rootErr := errors.New("database connection failed")
sessErr := &SessionError{
Operation: "save",
Message: "Failed to persist session",
SessionID: "sess456",
Cause: rootErr,
}
errStr := sessErr.Error()
assert.Contains(t, errStr, "Session error in save: Failed to persist session")
assert.Contains(t, errStr, "caused by: database connection failed")
})
t.Run("TokenError.Error without cause", func(t *testing.T) {
tokenErr := &TokenError{
TokenType: "access_token",
Reason: "expired",
Message: "Token has expired",
}
errStr := tokenErr.Error()
assert.Equal(t, "Token error (access_token) - expired: Token has expired", errStr)
})
t.Run("TokenError.Error with cause", func(t *testing.T) {
rootErr := errors.New("time check failed")
tokenErr := &TokenError{
TokenType: "id_token",
Reason: "expired",
Message: "Token validity period exceeded",
Cause: rootErr,
}
errStr := tokenErr.Error()
assert.Contains(t, errStr, "Token error (id_token) - expired: Token validity period exceeded")
assert.Contains(t, errStr, "caused by: time check failed")
})
}
// TestGracefulDegradationHealthChecks tests health check functionality
func TestGracefulDegradationHealthChecks(t *testing.T) {
logger := GetSingletonNoOpLogger()
t.Run("performHealthChecks recovers degraded service", func(t *testing.T) {
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Register health check that returns true
healthCheckCalled := false
gd.RegisterHealthCheck("test-service", func() bool {
healthCheckCalled = true
return true // Service is healthy
})
// Mark service as degraded
gd.markServiceDegraded("test-service")
// Verify service is degraded
assert.True(t, gd.isServiceDegraded("test-service"))
// Manually trigger health check
gd.performHealthChecks()
// Health check should have been called
assert.True(t, healthCheckCalled, "health check should be called")
// Service should be recovered
assert.False(t, gd.isServiceDegraded("test-service"), "service should be recovered")
})
t.Run("performHealthChecks marks service degraded on failure", func(t *testing.T) {
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Register health check that returns false
gd.RegisterHealthCheck("failing-service", func() bool {
return false // Service is unhealthy
})
// Initially not degraded
assert.False(t, gd.isServiceDegraded("failing-service"))
// Manually trigger health check
gd.performHealthChecks()
// Service should be marked degraded
assert.True(t, gd.isServiceDegraded("failing-service"), "service should be degraded")
})
t.Run("performHealthChecks runs multiple health checks independently", func(t *testing.T) {
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
service1Checked := false
service2Checked := false
gd.RegisterHealthCheck("service1", func() bool {
service1Checked = true
return true
})
gd.RegisterHealthCheck("service2", func() bool {
service2Checked = true
return true
})
// Manually trigger health checks
gd.performHealthChecks()
assert.True(t, service1Checked, "service1 health check should run")
assert.True(t, service2Checked, "service2 health check should run")
})
t.Run("performHealthChecks handles empty health checks", func(t *testing.T) {
config := DefaultGracefulDegradationConfig()
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Call performHealthChecks with no registered health checks
// Should not panic
assert.NotPanics(t, func() {
gd.performHealthChecks()
})
})
}
// TestGracefulDegradationServiceRecoveryTimeout tests recovery timeout behavior
func TestGracefulDegradationServiceRecoveryTimeout(t *testing.T) {
logger := GetSingletonNoOpLogger()
t.Run("service auto-recovers after timeout", func(t *testing.T) {
baseTimeout := GetTestDuration(50 * time.Millisecond)
config := GracefulDegradationConfig{
HealthCheckInterval: 1 * time.Hour, // Long interval, won't run during test
RecoveryTimeout: baseTimeout,
EnableFallbacks: true,
}
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Mark service degraded
gd.markServiceDegraded("auto-recover-service")
// Verify degraded
assert.True(t, gd.isServiceDegraded("auto-recover-service"))
// Wait for recovery timeout (longer than timeout to ensure recovery)
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
// Should auto-recover
assert.False(t, gd.isServiceDegraded("auto-recover-service"), "service should auto-recover after timeout")
})
t.Run("service remains degraded before timeout", func(t *testing.T) {
config := GracefulDegradationConfig{
HealthCheckInterval: 1 * time.Hour,
RecoveryTimeout: 1 * time.Hour, // Very long timeout
EnableFallbacks: true,
}
gd := NewGracefulDegradation(config, logger)
defer gd.Close()
// Mark service degraded
gd.markServiceDegraded("long-timeout-service")
// Verify degraded
assert.True(t, gd.isServiceDegraded("long-timeout-service"))
// Wait a bit
time.Sleep(GetTestDuration(10 * time.Millisecond))
// Should still be degraded
assert.True(t, gd.isServiceDegraded("long-timeout-service"), "service should remain degraded before timeout")
})
}
// TestErrorRecoveryManagerIntegration tests full integration of error recovery mechanisms
func TestErrorRecoveryManagerIntegration(t *testing.T) {
logger := GetSingletonNoOpLogger()
erm := NewErrorRecoveryManager(logger)
t.Run("circuit breaker and retry integration", func(t *testing.T) {
// Create a circuit breaker with higher max failures to allow retries
cb := NewCircuitBreaker(CircuitBreakerConfig{
MaxFailures: 10, // High threshold
Timeout: 60 * time.Second,
ResetTimeout: 30 * time.Second,
}, logger)
erm.mutex.Lock()
erm.circuitBreakers["test-service-integration"] = cb
erm.mutex.Unlock()
attempts := 0
fn := func() error {
attempts++
if attempts < 3 {
return errors.New("temporary failure")
}
return nil
}
err := erm.ExecuteWithRecovery(context.Background(), "test-service-integration", fn)
assert.NoError(t, err)
assert.GreaterOrEqual(t, attempts, 3, "should retry until success")
})
t.Run("circuit breaker opens on repeated failures", func(t *testing.T) {
fn := func() error {
return errors.New("persistent failure")
}
// First call - should fail after retries
err1 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
assert.Error(t, err1)
// Second call - should fail after retries
err2 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
assert.Error(t, err2)
// Check circuit breaker state
cb := erm.GetCircuitBreaker("failing-service")
state := cb.GetState()
assert.Equal(t, CircuitBreakerOpen, state, "circuit should be open after repeated failures")
})
t.Run("recovery metrics include all mechanisms", func(t *testing.T) {
metrics := erm.GetRecoveryMetrics()
assert.NotNil(t, metrics)
assert.Contains(t, metrics, "circuit_breakers")
assert.Contains(t, metrics, "degraded_services")
})
}
// TestContainsHelperFunction tests the contains helper function edge cases
func TestContainsHelperFunction(t *testing.T) {
t.Run("exact match", func(t *testing.T) {
assert.True(t, contains("timeout", "timeout"))
})
t.Run("prefix match", func(t *testing.T) {
assert.True(t, contains("timeout error occurred", "timeout"))
})
t.Run("suffix match", func(t *testing.T) {
assert.True(t, contains("connection timeout", "timeout"))
})
t.Run("middle match", func(t *testing.T) {
assert.True(t, contains("a connection timeout error", "timeout"))
})
t.Run("no match", func(t *testing.T) {
assert.False(t, contains("connection refused", "timeout"))
})
t.Run("substring longer than string", func(t *testing.T) {
assert.False(t, contains("abc", "abcdef"))
})
t.Run("empty substring", func(t *testing.T) {
assert.True(t, contains("test", ""))
})
t.Run("empty string", func(t *testing.T) {
assert.False(t, contains("", "test"))
})
t.Run("both empty", func(t *testing.T) {
assert.True(t, contains("", ""))
})
}
+1178 -37
View File
File diff suppressed because it is too large Load Diff
-797
View File
@@ -1,797 +0,0 @@
package features
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock types for testing
type TemplatedHeader struct {
Name string `json:"name"`
Value string `json:"value"`
}
type MockConfig struct {
ProviderURL string `json:"providerURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackURL"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
Headers []TemplatedHeader `json:"headers"`
}
// TestTemplateHeaderFeatures consolidates all template header-related tests
func TestTemplateHeaderFeatures(t *testing.T) {
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
t.Run("Template_Execution_Context", testTemplateExecutionContext)
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
t.Run("Missing_Field_Handling", testMissingFieldHandling)
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
}
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
// receive wrong data types during execution - reproduces GitHub issue #55
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
testCases := []struct {
name string
templateText string
templateData interface{}
errorContains string
expectError bool
}{
{
name: "correct map data",
templateText: "Bearer {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "valid-token",
},
expectError: false,
},
{
name: "boolean as root context - reproduces issue #55",
templateText: "Bearer {{.AccessToken}}",
templateData: true,
expectError: true,
errorContains: "can't evaluate field AccessToken in type bool",
},
{
name: "string as root context",
templateText: "Bearer {{.AccessToken}}",
templateData: "just a string",
expectError: true,
errorContains: "can't evaluate field AccessToken in type string",
},
{
name: "nested claims access with correct data",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
},
},
expectError: false,
},
{
name: "nested claims with wrong structure",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": "not a map",
},
expectError: true,
errorContains: "can't evaluate field email in type",
},
{
name: "complex nested structure",
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "token123",
"Claims": map[string]interface{}{
"sub": "user-id",
"groups": "admin,users",
},
},
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.templateData)
if tc.expectError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
} else {
require.NoError(t, err)
}
})
}
}
// testTemplateParsingValidation ensures templates are parsed correctly
func testTemplateParsingValidation(t *testing.T) {
testCases := []struct {
name string
headerTemplates []TemplatedHeader
shouldError bool
}{
{
name: "valid bearer token template",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
shouldError: false,
},
{
name: "multiple valid templates",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
shouldError: false,
},
{
name: "template with conditional logic",
headerTemplates: []TemplatedHeader{
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
},
shouldError: false,
},
{
name: "invalid template syntax",
headerTemplates: []TemplatedHeader{
{Name: "Bad-Template", Value: "{{.AccessToken"},
},
shouldError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, header := range tc.headerTemplates {
_, err := template.New(header.Name).Parse(header.Value)
if tc.shouldError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
})
}
}
// testMiddlewareHeaderTemplating simulates the actual middleware flow
func testMiddlewareHeaderTemplating(t *testing.T) {
testCases := []struct {
name string
headers []TemplatedHeader
accessToken string
idToken string
claims map[string]interface{}
expectedValues map[string]string
}{
{
name: "authorization header with access token",
headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
expectedValues: map[string]string{
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
},
},
{
name: "multiple headers with claims",
headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
},
accessToken: "token123",
claims: map[string]interface{}{
"email": "user@example.com",
"groups": "admin,developers",
},
expectedValues: map[string]string{
"X-User-Email": "user@example.com",
"X-User-Groups": "admin,developers",
"X-Auth-Token": "token123",
},
},
{
name: "complex template expressions",
headers: []TemplatedHeader{
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
},
accessToken: "access-token",
idToken: "id-token",
claims: map[string]interface{}{
"sub": "user-12345",
"email": "john@example.com",
},
expectedValues: map[string]string{
"X-User-Info": "user-12345 (john@example.com)",
"X-Auth-Header": "Bearer access-token | ID: id-token",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Parse all templates
headerTemplates := make(map[string]*template.Template)
for _, header := range tc.headers {
tmpl, err := template.New(header.Name).Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = tmpl
}
// Create template data
templateData := map[string]interface{}{
"AccessToken": tc.accessToken,
"IDToken": tc.idToken,
"Claims": tc.claims,
}
// Create a test request
req := httptest.NewRequest("GET", "/test", nil)
// Execute templates and set headers
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
req.Header.Set(headerName, buf.String())
}
// Verify all expected headers are set correctly
for headerName, expectedValue := range tc.expectedValues {
actualValue := req.Header.Get(headerName)
assert.Equal(t, expectedValue, actualValue)
}
})
}
}
// testJSONConfigParsing tests that JSON configuration is properly parsed
func testJSONConfigParsing(t *testing.T) {
testCases := []struct {
name string
jsonConfig string
expectedError bool
description string
}{
{
name: "valid JSON configuration",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": "Bearer {{.AccessToken}}"
}
]
}`,
expectedError: false,
description: "Properly formatted JSON with string values",
},
{
name: "JSON with boolean value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": true
}
]
}`,
expectedError: true,
description: "Boolean value instead of string template",
},
{
name: "JSON with number value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": 123
}
]
}`,
expectedError: true,
description: "Number value instead of string template",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var config struct {
Headers []TemplatedHeader `json:"headers"`
}
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
if tc.expectedError {
require.Error(t, err, tc.description)
} else {
require.NoError(t, err, tc.description)
}
})
}
}
// testTemplateDoubleProcessing tests if template strings are being double-processed
func testTemplateDoubleProcessing(t *testing.T) {
// Simulate how Traefik passes config to the plugin
config := &MockConfig{
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Verify that template strings are still raw (not processed)
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
// Simulate template parsing during initialization
headerTemplates := make(map[string]*template.Template)
funcMap := template.FuncMap{
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" || val == "<no value>" {
return defaultVal
}
return val
},
"get": func(m interface{}, key string) interface{} {
if mapVal, ok := m.(map[string]interface{}); ok {
if val, exists := mapVal[key]; exists {
return val
}
}
return ""
},
}
for _, header := range config.Headers {
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
parsedTmpl, err := tmpl.Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = parsedTmpl
}
// Test execution with actual claims
claims := map[string]interface{}{
"email": "user@example.com",
// Note: internal_role is missing
}
templateData := map[string]interface{}{
"Claims": claims,
}
// Execute templates
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
result := buf.String()
if headerName == "X-User-Email" {
assert.Equal(t, "user@example.com", result)
} else if headerName == "X-User-Role" {
// With missingkey=zero, missing fields return "<no value>"
assert.Equal(t, "<no value>", result)
}
}
}
// testTemplateExecutionContext tests the specific template data context
func testTemplateExecutionContext(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expectedValue string
}{
{
name: "Access and ID token distinction",
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IDToken": "id-token-value",
"Claims": map[string]interface{}{},
},
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "access-token",
"IDToken": "id-token",
"Claims": map[string]interface{}{
"sub": "user123",
},
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Custom non-standard claims",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"Claims": map[string]interface{}{
"role": "admin",
"permissions": "read:all,write:own",
},
},
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
},
{
name: "Deeply nested custom claims",
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"app_metadata": map[string]interface{}{
"organization": map[string]interface{}{
"name": "acme-corp",
},
"team": "platform",
},
},
},
expectedValue: "X-Organization: acme-corp, X-Team: platform",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expectedValue, buf.String())
})
}
}
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
func testTemplateIntegrationWithPlugin(t *testing.T) {
// Test template integration using mock plugin components
// Set up test OIDC server
var testServerURL string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": testServerURL,
"authorization_endpoint": testServerURL + "/auth",
"token_endpoint": testServerURL + "/token",
"jwks_uri": testServerURL + "/jwks",
"userinfo_endpoint": testServerURL + "/userinfo",
})
case "/jwks":
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
http.NotFound(w, r)
}
}))
defer testServer.Close()
testServerURL = testServer.URL
// Create config with templates that reference potentially missing fields
config := &MockConfig{
ProviderURL: testServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-characters",
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Initialize plugin would be done here
ctx := context.Background()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test would create plugin handler here
_ = ctx
_ = next
_ = config
}
// testTemplateSyntaxValidation tests that template syntax is properly validated
func testTemplateSyntaxValidation(t *testing.T) {
validTemplates := []string{
"{{.Claims.email}}",
"{{.Claims.internal_role}}",
"{{.AccessToken}}",
"{{.IdToken}}",
"{{.RefreshToken}}",
}
for _, tmplStr := range validTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
}
// Test invalid templates
invalidTemplates := []struct {
template string
reason string
}{
{"{{call .SomeFunc}}", "function calls not allowed"},
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
{"{{index .Array 0}}", "index access blocked"},
{"{{printf \"%s\" .Data}}", "printf blocked"},
}
for _, tc := range invalidTemplates {
err := validateTemplateSecure(tc.template)
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
}
// Test safe custom functions
safeTemplates := []string{
"{{get .Claims \"internal_role\"}}",
"{{default \"guest\" .Claims.role}}",
}
for _, tmplStr := range safeTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
}
}
// Mock validation function for template security
func validateTemplateSecure(templateStr string) error {
// List of potentially dangerous template actions
dangerousFunctions := []string{
"call", "range", "with", "index", "printf", "println", "print",
"js", "html", "urlquery", "base64", "exec",
}
for _, dangerous := range dangerousFunctions {
if strings.Contains(templateStr, dangerous) {
return fmt.Errorf("dangerous template function detected: %s", dangerous)
}
}
// Define safe custom functions
funcMap := template.FuncMap{
"get": func(data map[string]interface{}, key string) interface{} {
return data[key]
},
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" {
return defaultVal
}
return val
},
}
// Try to parse the template with custom functions to check for syntax errors
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
return err
}
// testMissingFieldHandling tests handling of missing fields in templates
func testMissingFieldHandling(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "missing claim field",
templateText: "{{.Claims.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{},
},
expected: "<no value>",
},
{
name: "missing nested field",
templateText: "{{.Claims.user.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"user": map[string]interface{}{},
},
},
expected: "<no value>",
},
{
name: "missing entire path",
templateText: "{{.Missing.Path.Field}}",
data: map[string]interface{}{},
expected: "<no value>",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testComplexTemplateExpressions tests complex template expressions
func testComplexTemplateExpressions(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "conditional template",
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"admin": true,
},
},
expected: "Admin User",
},
{
name: "multiple claims concatenation",
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"firstName": "John",
"lastName": "Doe",
"email": "john.doe@example.com",
},
},
expected: "John Doe <john.doe@example.com>",
},
{
name: "array access",
templateText: "{{index .Claims.roles 0}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"roles": []string{"admin", "user"},
},
},
expected: "admin",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
func testTraefikConfigurationParsing(t *testing.T) {
testCases := []struct {
name string
config *MockConfig
expectError bool
description string
}{
{
name: "valid configuration with templated headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
},
expectError: false,
description: "Standard configuration should work",
},
{
name: "configuration with multiple headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
},
expectError: false,
description: "Multiple headers should work",
},
{
name: "empty headers configuration",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{},
},
expectError: false,
description: "Empty headers should not cause issues",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a simple next handler
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Try to create the middleware would be done here
ctx := context.Background()
// Test would create middleware handler here
_ = ctx
_ = next
_ = tc.config
// For now, we just validate the configuration is well-formed
if !tc.expectError {
require.NotNil(t, tc.config, tc.description)
require.NotEmpty(t, tc.config.ClientID, tc.description)
}
})
}
}
+2 -1
View File
@@ -6,7 +6,7 @@ require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.14.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
@@ -18,5 +18,6 @@ require (
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
)
+4 -2
View File
@@ -20,8 +20,10 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
-764
View File
@@ -1,764 +0,0 @@
package handlers
import (
"errors"
"net/http"
"sync"
"testing"
"time"
)
// ============================================================================
// OAuth Handler Tests
// ============================================================================
func TestOAuthHandler(t *testing.T) {
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
// Test authorization request handling logic
logger := &MockLogger{}
tests := []struct {
name string
requestURL string
expectedStatus int
checkLocation bool
}{
{
name: "Valid authorization request",
requestURL: "/auth/login",
expectedStatus: http.StatusFound,
checkLocation: true,
},
{
name: "With return URL",
requestURL: "/auth/login?return=/dashboard",
expectedStatus: http.StatusFound,
checkLocation: true,
},
}
// Test the test case structure
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.requestURL == "" {
t.Error("Request URL should not be empty")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// In a real implementation, this would test the actual handler
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Authorization request test completed")
})
t.Run("HandleCallbackRequest", func(t *testing.T) {
// Test callback request handling with existing mocks
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
tests := []struct {
name string
queryParams string
expectedStatus int
expectError bool
}{
{
name: "Valid callback with code",
queryParams: "code=test-code&state=test-state",
expectedStatus: http.StatusFound,
expectError: false,
},
{
name: "Callback with error",
queryParams: "error=access_denied&error_description=User denied access",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing code",
queryParams: "state=test-state",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing state",
queryParams: "code=test-code",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
}
// Test the callback scenarios
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.queryParams == "" && !test.expectError {
t.Error("Query params should not be empty for successful cases")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// Test session manager functionality
if sessionManager != nil {
t.Logf("Session manager available for test %s", test.name)
}
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Callback request test completed")
})
t.Run("HandleLogout", func(t *testing.T) {
// Test logout functionality with mock implementations
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
// Test session clearing
mockReq := &http.Request{}
session, err := sessionManager.GetSession(mockReq)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up authenticated session
err = session.SetAuthenticated(true)
if err != nil {
t.Fatalf("Failed to set authentication: %v", err)
}
session.SetIDToken("test-token")
// Verify session is authenticated
if !session.GetAuthenticated() {
t.Error("Session should be authenticated before logout")
}
// Test logout by clearing session
// session.Clear() // Method not implemented in SessionData
// Additional logout verification would go here
// Verify logger doesn't cause issues
logger.Debugf("Logout test completed")
t.Log("Logout test completed successfully")
})
}
// ============================================================================
// Auth Handler Tests
// ============================================================================
func TestAuthHandler(t *testing.T) {
t.Run("HandleAuthentication", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
/*
handler := &MockAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
tests := []struct {
name string
setupSession func(*MockSession)
expectedStatus int
expectNext bool
}{
{
name: "Authenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("valid-token")
},
expectedStatus: http.StatusOK,
expectNext: true,
},
{
name: "Unauthenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(false)
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
{
name: "Expired token",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("expired-token")
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("HandleRefreshToken", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
tests := []struct {
name string
refreshToken string
mockResponse *MockTokenResponse
mockError error
expectSuccess bool
}{
{
name: "Successful refresh",
refreshToken: "valid-refresh-token",
mockResponse: &MockTokenResponse{
AccessToken: "new-access-token",
IDToken: "new-id-token",
RefreshToken: "new-refresh-token",
},
expectSuccess: true,
},
{
name: "Failed refresh",
refreshToken: "invalid-refresh-token",
mockError: errors.New("invalid_grant"),
expectSuccess: false,
},
{
name: "Empty refresh token",
refreshToken: "",
expectSuccess: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Error Handler Tests
// ============================================================================
func TestErrorHandler(t *testing.T) {
t.Run("HandleHTTPErrors", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
errorCode int
errorMessage string
isAjax bool
expectedStatus int
expectedBody string
}{
{
name: "401 Unauthorized",
errorCode: http.StatusUnauthorized,
errorMessage: "Authentication required",
isAjax: false,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Authentication required",
},
{
name: "403 Forbidden",
errorCode: http.StatusForbidden,
errorMessage: "Access denied",
isAjax: false,
expectedStatus: http.StatusForbidden,
expectedBody: "Access denied",
},
{
name: "500 Internal Server Error",
errorCode: http.StatusInternalServerError,
errorMessage: "Internal server error",
isAjax: false,
expectedStatus: http.StatusInternalServerError,
expectedBody: "Internal server error",
},
{
name: "Ajax 401",
errorCode: http.StatusUnauthorized,
errorMessage: "Token expired",
isAjax: true,
expectedStatus: http.StatusUnauthorized,
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("RecoverFromPanic", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
panicValue interface{}
expectError bool
}{
{
name: "String panic",
panicValue: "something went wrong",
expectError: true,
},
{
name: "Error panic",
panicValue: errors.New("critical error"),
expectError: true,
},
{
name: "Nil panic",
panicValue: nil,
expectError: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Azure OAuth Callback Tests
// ============================================================================
func TestAzureOAuthCallback(t *testing.T) {
t.Run("AzureSpecificClaims", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
azureClaims := map[string]interface{}{
"oid": "object-id",
"tid": "tenant-id",
"preferred_username": "user@example.com",
"name": "Test User",
"email": "user@example.com",
"groups": []string{"group1", "group2"},
}
// Test would go here when properly implemented
_ = azureClaims
})
t.Run("AzureTokenValidation", func(t *testing.T) {
// Test with mock validator types
/*
validator := &MockAzureTokenValidator{
tenantID: "test-tenant",
clientID: "test-client",
}
*/
tests := []struct {
name string
token string
claims map[string]interface{}
expectValid bool
}{
{
name: "Valid Azure token",
token: "valid-azure-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: true,
},
{
name: "Wrong tenant",
token: "wrong-tenant-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "wrong-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
{
name: "Wrong audience",
token: "wrong-audience-token",
claims: map[string]interface{}{
"aud": "wrong-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Concurrent Handler Tests
// ============================================================================
func TestConcurrentHandlers(t *testing.T) {
t.Run("ConcurrentCallbacks", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
successCount := int32(0)
errorCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = successCount
_ = errorCount
})
t.Run("ConcurrentLogouts", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
logoutCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = logoutCount
})
}
// ============================================================================
// Mock Implementations
// ============================================================================
type MockSessionManager struct {
sessions map[string]*MockSession
mu sync.RWMutex
}
func NewMockSessionManager() *MockSessionManager {
return &MockSessionManager{
sessions: make(map[string]*MockSession),
}
}
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
m.mu.Lock()
defer m.mu.Unlock()
sessionID := "test-session"
if session, exists := m.sessions[sessionID]; exists {
return session, nil
}
session := &MockSession{
values: make(map[string]interface{}),
}
m.sessions[sessionID] = session
return session, nil
}
type MockSession struct {
values map[string]interface{}
mu sync.RWMutex
}
func (s *MockSession) SetAuthenticated(auth bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.values["authenticated"] = auth
return nil
}
func (s *MockSession) GetAuthenticated() bool {
s.mu.RLock()
defer s.mu.RUnlock()
auth, ok := s.values["authenticated"].(bool)
return ok && auth
}
func (s *MockSession) SetIDToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["id_token"] = token
}
func (s *MockSession) GetIDToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["id_token"].(string)
return token
}
func (s *MockSession) SetAccessToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["access_token"] = token
}
func (s *MockSession) GetAccessToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["access_token"].(string)
return token
}
func (s *MockSession) SetRefreshToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["refresh_token"] = token
}
func (s *MockSession) GetRefreshToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["refresh_token"].(string)
return token
}
func (s *MockSession) SetState(state string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["state"] = state
}
func (s *MockSession) GetState() string {
s.mu.RLock()
defer s.mu.RUnlock()
state, _ := s.values["state"].(string)
return state
}
func (s *MockSession) SetClaims(claims map[string]interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["claims"] = claims
}
func (s *MockSession) GetClaims() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
claims, _ := s.values["claims"].(map[string]interface{})
return claims
}
// Additional SessionData interface methods to match real interface
func (s *MockSession) GetCSRF() string {
s.mu.RLock()
defer s.mu.RUnlock()
csrf, _ := s.values["csrf"].(string)
return csrf
}
func (s *MockSession) GetNonce() string {
s.mu.RLock()
defer s.mu.RUnlock()
nonce, _ := s.values["nonce"].(string)
return nonce
}
func (s *MockSession) GetCodeVerifier() string {
s.mu.RLock()
defer s.mu.RUnlock()
verifier, _ := s.values["code_verifier"].(string)
return verifier
}
func (s *MockSession) GetIncomingPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
path, _ := s.values["incoming_path"].(string)
return path
}
func (s *MockSession) SetEmail(email string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["email"] = email
}
func (s *MockSession) GetEmail() string {
s.mu.RLock()
defer s.mu.RUnlock()
email, _ := s.values["email"].(string)
return email
}
func (s *MockSession) SetCSRF(csrf string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["csrf"] = csrf
}
func (s *MockSession) SetNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["nonce"] = nonce
}
func (s *MockSession) SetCodeVerifier(verifier string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["code_verifier"] = verifier
}
func (s *MockSession) SetIncomingPath(path string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["incoming_path"] = path
}
func (s *MockSession) ResetRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.values["redirect_count"] = 0
}
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
return nil
}
func (s *MockSession) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.values = make(map[string]interface{})
}
func (s *MockSession) returnToPoolSafely() {
// No-op for mock
}
type MockTokenValidator struct {
valid bool
}
func (v *MockTokenValidator) Validate(token string) bool {
if token == "expired-token" {
return false
}
return v.valid
}
// ============================================================================
// Mock Handler Type Definitions (for testing)
// ============================================================================
// These mock handlers are simplified versions for testing purposes
// They don't match the actual handler implementations
type MockAuthHandler struct{}
type MockErrorHandler struct{}
type MockAzureTokenValidator struct {
tenantID string
clientID string
}
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
// Validate tenant ID
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
return false
}
// Validate audience
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
return false
}
// Validate expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return false
}
}
return true
}
// ============================================================================
// Helper Types and Mock Logger
// ============================================================================
type MockLogger struct{}
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
func (l *MockLogger) Error(msg string) {}
type MockTokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
}
-313
View File
@@ -1,313 +0,0 @@
// Package handlers provides HTTP request handlers for the OIDC middleware.
package handlers
import (
"context"
"fmt"
"net/http"
"strings"
)
// OAuthHandler handles OAuth callback requests
type OAuthHandler struct {
logger Logger
sessionManager SessionManager
tokenExchanger TokenExchanger
tokenVerifier TokenVerifier
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
isAllowedDomainFunc func(email string) bool
redirURLPath string
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Error(msg string)
}
// SessionManager interface for session operations
type SessionManager interface {
GetSession(req *http.Request) (SessionData, error)
}
// SessionData interface for session data operations
type SessionData interface {
GetCSRF() string
GetNonce() string
GetCodeVerifier() string
GetIncomingPath() string
GetAuthenticated() bool
GetAccessToken() string
GetRefreshToken() string
GetIDToken() string
GetEmail() string
SetAuthenticated(bool) error
SetEmail(string)
SetIDToken(string)
SetAccessToken(string)
SetRefreshToken(string)
SetCSRF(string)
SetNonce(string)
SetCodeVerifier(string)
SetIncomingPath(string)
ResetRedirectCount()
Save(req *http.Request, rw http.ResponseWriter) error
returnToPoolSafely()
}
// TokenExchanger interface for token operations
type TokenExchanger interface {
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
}
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// TokenResponse represents the response from token exchange
type TokenResponse struct {
IDToken string
AccessToken string
RefreshToken string
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
isAllowedDomainFunc func(string) bool, redirURLPath string,
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
return &OAuthHandler{
logger: logger,
sessionManager: sessionManager,
tokenExchanger: tokenExchanger,
tokenVerifier: tokenVerifier,
extractClaimsFunc: extractClaimsFunc,
isAllowedDomainFunc: isAllowedDomainFunc,
redirURLPath: redirURLPath,
sendErrorResponseFunc: sendErrorResponseFunc,
}
}
// HandleCallback handles OAuth callback requests
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := h.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Session error during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Debug logging for cookie configuration
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
// Log all cookies in the request for debugging
cookies := req.Cookies()
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, "_oidc_") {
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
}
}
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
h.logger.Error("No state in callback")
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
// Debug log the state parameter received
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
csrfToken := session.GetCSRF()
if csrfToken == "" {
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
// Enhanced debugging for missing CSRF token
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
h.logger.Errorf("Main session cookie not found in request: %v", err)
// Log cookie names only, not values (avoid logging sensitive session data)
cookieNames := make([]string, 0, len(req.Cookies()))
for _, c := range req.Cookies() {
cookieNames = append(cookieNames, c.Name)
}
h.logger.Debugf("Available cookies (names only): %v", cookieNames)
} else {
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
}
// Log session state for debugging
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
session.GetAuthenticated(), session.GetAccessToken() != "")
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
// Debug log successful CSRF token retrieval
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
if state != csrfToken {
h.logger.Error("State parameter does not match CSRF token in session during callback")
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
h.logger.Error("No code in callback")
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
h.logger.Errorf("Failed to extract claims during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
h.logger.Error("Nonce claim missing in id_token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
h.logger.Error("Nonce not found in session during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
h.logger.Error("Nonce claim does not match session nonce during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
email, _ := claims["email"].(string)
if email == "" {
h.logger.Errorf("Email claim missing or empty in token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !h.isAllowedDomainFunc(email) {
h.logger.Errorf("Disallowed email domain during callback: %s", email)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session after callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// URLHelper provides utility methods for URL operations
type URLHelper struct {
logger Logger
}
// NewURLHelper creates a new URL helper
func NewURLHelper(logger Logger) *URLHelper {
return &URLHelper{logger: logger}
}
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
// It compares the request path against configured excluded URL prefixes.
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
for excludedURL := range excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
return false
}
// DetermineScheme determines the URL scheme for building redirect URLs.
// It checks X-Forwarded-Proto header first, then TLS presence.
func (h *URLHelper) DetermineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// DetermineHost determines the host for building redirect URLs.
// It checks X-Forwarded-Host header first, then falls back to req.Host.
func (h *URLHelper) DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
-899
View File
@@ -1,899 +0,0 @@
package handlers
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// Test mocks - implementing interfaces defined in oauth_handler.go
type mockLogger struct {
debugMessages []string
errorMessages []string
}
func (l *mockLogger) Debugf(format string, args ...interface{}) {
l.debugMessages = append(l.debugMessages, format)
}
func (l *mockLogger) Errorf(format string, args ...interface{}) {
l.errorMessages = append(l.errorMessages, format)
}
func (l *mockLogger) Error(msg string) {
l.errorMessages = append(l.errorMessages, msg)
}
type mockSessionManager struct {
sessionToReturn SessionData
errorToReturn error
}
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
return m.sessionToReturn, m.errorToReturn
}
type mockSessionData struct {
authenticated bool
email string
csrf string
nonce string
codeVerifier string
incomingPath string
accessToken string
refreshToken string
idToken string
saveError error
setAuthError error
}
func (s *mockSessionData) GetCSRF() string { return s.csrf }
func (s *mockSessionData) GetNonce() string { return s.nonce }
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
func (s *mockSessionData) GetIDToken() string { return s.idToken }
func (s *mockSessionData) GetEmail() string { return s.email }
func (s *mockSessionData) SetAuthenticated(auth bool) error {
s.authenticated = auth
return s.setAuthError
}
func (s *mockSessionData) SetEmail(email string) { s.email = email }
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
func (s *mockSessionData) ResetRedirectCount() {}
func (s *mockSessionData) returnToPoolSafely() {}
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
return s.saveError
}
type mockTokenExchanger struct {
response *TokenResponse
err error
}
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
return e.response, e.err
}
type mockTokenVerifier struct {
err error
}
func (v *mockTokenVerifier) VerifyToken(token string) error {
return v.err
}
// TestOAuthHandler_NewOAuthHandler tests the constructor
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
logger := &mockLogger{}
sessionManager := &mockSessionManager{}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.logger != logger {
t.Error("Logger not set correctly")
}
if handler.redirURLPath != "/callback" {
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
}
}
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
logger := &mockLogger{}
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return nil, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Session error") {
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "Authentication error from provider") {
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
// Test with error parameter
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "State parameter missing") {
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: ""} // Empty CSRF
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "CSRF token missing") {
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "different-token"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "CSRF mismatch") {
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "No authorization code received") {
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not exchange code for token") {
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not verify ID token") {
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return nil, errors.New("claims extraction failed")
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not extract claims") {
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
// Claims without nonce
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce missing in token") {
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce missing in session") {
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce mismatch") {
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Email missing in token") {
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return false } // Disallow all domains
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusForbidden {
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
}
if !strings.Contains(msg, "Email domain not allowed") {
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
saveError: errors.New("save failed"),
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Failed to save session") {
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
setAuthError: errors.New("set auth failed"),
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Failed to update session") {
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "/dashboard",
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{
IDToken: "valid-id-token",
AccessToken: "valid-access-token",
RefreshToken: "valid-refresh-token",
}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if errorSent {
t.Error("Unexpected error response sent")
}
// Check redirect
if rw.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
}
location := rw.Header().Get("Location")
if location != "/dashboard" {
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
}
// Verify session data was set correctly
if session.email != "test@example.com" {
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
}
if session.idToken != "valid-id-token" {
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
}
if session.accessToken != "valid-access-token" {
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
}
if session.refreshToken != "valid-refresh-token" {
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
}
if !session.authenticated {
t.Error("Expected session to be authenticated")
}
// Check that temporary fields are cleared
if session.csrf != "" {
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
}
if session.nonce != "" {
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
}
if session.codeVerifier != "" {
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
}
if session.incomingPath != "" {
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
}
}
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "", // No incoming path, should default to "/"
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
// Check redirect to default path
if rw.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
}
location := rw.Header().Get("Location")
if location != "/" {
t.Errorf("Expected redirect to '/', got '%s'", location)
}
}
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "/callback", // Same as redirect URL path
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
// Should redirect to default path when incoming path is same as callback path
location := rw.Header().Get("Location")
if location != "/" {
t.Errorf("Expected redirect to '/', got '%s'", location)
}
}
-454
View File
@@ -1,454 +0,0 @@
package handlers
import (
"crypto/tls"
"net/http"
"testing"
)
// TestURLHelper_NewURLHelper tests the URLHelper constructor
func TestURLHelper_NewURLHelper(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
if helper == nil {
t.Fatal("Expected URLHelper to be created, got nil")
}
if helper.logger != logger {
t.Error("Logger not set correctly")
}
}
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
currentURL string
excludedURLs map[string]struct{}
expected bool
}{
{
name: "Exact match",
currentURL: "/health",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
{
name: "Prefix match",
currentURL: "/health/status",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
{
name: "No match",
currentURL: "/api/users",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: false,
},
{
name: "Multiple exclusions - first match",
currentURL: "/api/health",
excludedURLs: map[string]struct{}{
"/api": {},
"/health": {},
},
expected: true,
},
{
name: "Multiple exclusions - second match",
currentURL: "/health/check",
excludedURLs: map[string]struct{}{
"/api": {},
"/health": {},
},
expected: true,
},
{
name: "Empty excluded URLs",
currentURL: "/api/users",
excludedURLs: map[string]struct{}{},
expected: false,
},
{
name: "Root path exclusion",
currentURL: "/anything",
excludedURLs: map[string]struct{}{
"/": {},
},
expected: true,
},
{
name: "Case sensitive matching",
currentURL: "/API/users",
excludedURLs: map[string]struct{}{
"/api": {},
},
expected: false,
},
{
name: "Partial substring but not prefix",
currentURL: "/user/api/test",
excludedURLs: map[string]struct{}{
"/api": {},
},
expected: false,
},
{
name: "Empty current URL",
currentURL: "",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: false,
},
{
name: "URL with query parameters",
currentURL: "/health?status=ok",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
if result != tt.expected {
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
}
// Verify debug logging for excluded URLs
if result && len(logger.debugMessages) > 0 {
// Should have logged a debug message for excluded URL
found := false
for _, msg := range logger.debugMessages {
if msg == "URL is excluded - got %s / excluded hit: %s" {
found = true
break
}
}
if !found {
t.Error("Expected debug message for excluded URL")
}
}
// Reset logger messages for next test
logger.debugMessages = nil
})
}
}
// TestURLHelper_DetermineScheme tests scheme determination
func TestURLHelper_DetermineScheme(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedScheme string
}{
{
name: "X-Forwarded-Proto header present - https",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
return req
},
expectedScheme: "https",
},
{
name: "X-Forwarded-Proto header present - http",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "http")
return req
},
expectedScheme: "http",
},
{
name: "TLS connection without X-Forwarded-Proto",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
return req
},
expectedScheme: "https",
},
{
name: "No TLS and no X-Forwarded-Proto",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
return req
},
expectedScheme: "http",
},
{
name: "X-Forwarded-Proto takes precedence over TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
req.Header.Set("X-Forwarded-Proto", "http")
return req
},
expectedScheme: "http",
},
{
name: "Empty X-Forwarded-Proto falls back to TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
req.Header.Set("X-Forwarded-Proto", "")
return req
},
expectedScheme: "https",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
result := helper.DetermineScheme(req)
if result != tt.expectedScheme {
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
}
})
}
}
// TestURLHelper_DetermineHost tests host determination
func TestURLHelper_DetermineHost(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedHost string
}{
{
name: "X-Forwarded-Host header present",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "public.example.com")
return req
},
expectedHost: "public.example.com",
},
{
name: "No X-Forwarded-Host, use req.Host",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "direct.example.com"
return req
},
expectedHost: "direct.example.com",
},
{
name: "Empty X-Forwarded-Host falls back to req.Host",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "fallback.example.com"
req.Header.Set("X-Forwarded-Host", "")
return req
},
expectedHost: "fallback.example.com",
},
{
name: "X-Forwarded-Host with port",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com:8080"
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
return req
},
expectedHost: "public.example.com:443",
},
{
name: "req.Host with port",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
req.Host = "example.com:8080"
return req
},
expectedHost: "example.com:8080",
},
{
name: "Multiple X-Forwarded-Host values (first one used)",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
return req
},
expectedHost: "first.example.com, second.example.com", // Header value as-is
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
result := helper.DetermineHost(req)
if result != tt.expectedHost {
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
}
})
}
}
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedScheme string
expectedHost string
}{
{
name: "Both headers present",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "public.example.com")
return req
},
expectedScheme: "https",
expectedHost: "public.example.com",
},
{
name: "Neither header present, TLS connection",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
return req
},
expectedScheme: "https",
expectedHost: "secure.example.com",
},
{
name: "Neither header present, no TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
req.Host = "plain.example.com"
return req
},
expectedScheme: "http",
expectedHost: "plain.example.com",
},
{
name: "Mixed - only scheme header",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
req.Host = "mixed.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
return req
},
expectedScheme: "https",
expectedHost: "mixed.example.com",
},
{
name: "Mixed - only host header",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "external.example.com")
return req
},
expectedScheme: "http",
expectedHost: "external.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
scheme := helper.DetermineScheme(req)
host := helper.DetermineHost(req)
if scheme != tt.expectedScheme {
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
}
if host != tt.expectedHost {
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
}
// Test that we can build a complete URL
fullURL := scheme + "://" + host + "/callback"
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
if fullURL != expectedURL {
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
}
})
}
}
// Benchmark tests to ensure the helper methods are performant
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
excludedURLs := map[string]struct{}{
"/health": {},
"/metrics": {},
"/status": {},
"/api/v1": {},
"/api/v2": {},
"/static": {},
"/assets": {},
"/favicon": {},
"/robots": {},
"/sitemap": {},
}
testURL := "/api/users"
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineExcludedURL(testURL, excludedURLs)
}
}
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineScheme(req)
}
}
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "external.example.com")
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineHost(req)
}
}
+4 -2
View File
@@ -13,6 +13,8 @@ import (
"net/url"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
@@ -349,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
host := utils.DetermineHost(req)
scheme := utils.DetermineScheme(req, t.forceHTTPS)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
+52 -47
View File
@@ -15,20 +15,21 @@ import (
// XSS, path traversal, and other injection attacks. It validates and sanitizes
// various input types used in OIDC authentication flows.
type InputValidator struct {
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
}
// ValidationResult encapsulates the outcome of input validation.
@@ -46,13 +47,14 @@ type ValidationResult struct {
// It specifies maximum lengths for various input types and controls whether
// strict validation mode is enabled.
type InputValidationConfig struct {
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
}
// DefaultInputValidationConfig returns a secure default configuration
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
}
}
// Check for private IP ranges (RFC 1918)
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
if !iv.allowPrivateIPAddresses {
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
+2 -2
View File
@@ -76,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
// Test connectivity
if err := backend.Ping(context.Background()); err != nil {
pool.Close()
_ = pool.Close()
return nil, fmt.Errorf("failed to ping Redis: %w", err)
}
@@ -263,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
if err != nil {
continue
}
conn.Do("DEL", key) // Best effort, ignore errors
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
}
return nil
+23 -22
View File
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// Reuse existing connection - validate if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
// Connection is stale, close it and try again
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
default:
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
if err != nil {
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
case conn = <-p.connections:
// Validate connection if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
p.activeConns.Add(-1)
if p.closed.Load() || conn.closed.Load() {
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
return
}
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
// Successfully returned to pool
default:
// Pool full, close connection
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
}
}
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
// Close all pooled connections
for conn := range p.connections {
conn.Close()
_ = conn.Close()
}
return nil
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Authenticate if password is provided
if p.config.Password != "" {
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
redisConn.Close()
_ = redisConn.Close()
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Select database
if p.config.DB != 0 {
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
redisConn.Close()
_ = redisConn.Close()
return nil, fmt.Errorf("failed to select database: %w", err)
}
}
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
c.mu.Lock()
defer c.mu.Unlock()
// Build command arguments
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
// Limit to a safe value that prevents integer overflow in allocation size calculation
// (capacity * sizeof(string) must fit in int/size_t)
argsLen := len(args)
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
if argsLen < 0 || argsLen > maxSafeArgs {
return nil, errors.New("too many arguments")
// Validate argument count to prevent integer overflow in slice operations
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
const maxSafeArgs = (1 << 20) - 1
if len(args) > maxSafeArgs {
return nil, errors.New("too many arguments: exceeds maximum safe count")
}
// Build command arguments
// Validate total argument size to prevent memory exhaustion
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
totalBytes := len(command)
for _, s := range args {
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
return nil, errors.New("total argument size exceeds maximum allowed")
}
}
cmdArgs := make([]string, 0, argsLen+1)
cmdArgs = append(cmdArgs, command)
cmdArgs = append(cmdArgs, args...)
// Build command slice: prepend command to args
// Using append avoids arithmetic on potentially large len(args)
cmdArgs := append([]string{command}, args...)
// Set write timeout
if c.writeTimeout > 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// Write command (using pooled writer for memory efficiency)
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
// Set read timeout
if c.readTimeout > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
// Read response (using pooled reader for memory efficiency)
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
// Set a read deadline for the ping
if conn.conn != nil {
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
}
_, err := conn.Do("PING")
+3
View File
@@ -158,6 +158,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
case StateHalfOpen:
// Allow limited requests in half-open state
current := cb.halfOpenRequests.Add(1)
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
return current <= int32(cb.config.HalfOpenMaxRequests)
default:
@@ -181,6 +182,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
case StateHalfOpen:
// If we've had enough successful requests, close the circuit
successfulRequests := cb.halfOpenRequests.Load()
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
@@ -203,6 +205,7 @@ func (cb *CircuitBreaker) RecordFailure() {
switch state {
case StateClosed:
// Check if we should open the circuit
// #nosec G115 -- MaxFailures is a small config value that fits in int32
if failures >= int32(cb.config.MaxFailures) {
cb.openCircuit()
} else if cb.config.FailureThreshold > 0 {
+2
View File
@@ -217,6 +217,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
newStatus := currentStatus
// Check if we should become healthy
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
if successes >= int32(hc.config.HealthyThreshold) {
if latency > hc.config.DegradedThreshold {
newStatus = HealthDegraded
@@ -241,6 +242,7 @@ func (hc *HealthChecker) recordFailure() {
hc.timeMu.Unlock()
// Check if we should become unhealthy
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
if failures >= int32(hc.config.UnhealthyThreshold) {
hc.setStatus(HealthUnhealthy)
}
+1
View File
@@ -150,6 +150,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
// recordResult records the result of an operation for health tracking
func (h *HealthCheckBackend) recordResult(success bool) {
// #nosec G115 -- threshold config values are small integers that fit in int32
if success {
fails := h.consecutiveFails.Swap(0)
oks := h.consecutiveOK.Add(1)
-219
View File
@@ -1,219 +0,0 @@
// Package errors provides unified error handling for OIDC operations
package errors
import (
"fmt"
"net/http"
)
// ErrorCode represents specific error types
type ErrorCode string
const (
// Authentication errors
ErrCodeAuthenticationFailed ErrorCode = "AUTH_FAILED"
ErrCodeTokenExpired ErrorCode = "TOKEN_EXPIRED"
ErrCodeTokenInvalid ErrorCode = "TOKEN_INVALID"
ErrCodeSessionExpired ErrorCode = "SESSION_EXPIRED"
ErrCodeCSRFMismatch ErrorCode = "CSRF_MISMATCH"
ErrCodeNonceMismatch ErrorCode = "NONCE_MISMATCH"
// Configuration errors
ErrCodeConfigInvalid ErrorCode = "CONFIG_INVALID"
ErrCodeProviderUnreachable ErrorCode = "PROVIDER_UNREACHABLE"
ErrCodeMetadataFailed ErrorCode = "METADATA_FAILED"
// Network errors
ErrCodeNetworkTimeout ErrorCode = "NETWORK_TIMEOUT"
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
ErrCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
// Validation errors
ErrCodeValidationFailed ErrorCode = "VALIDATION_FAILED"
ErrCodeDomainNotAllowed ErrorCode = "DOMAIN_NOT_ALLOWED"
ErrCodeUserNotAllowed ErrorCode = "USER_NOT_ALLOWED"
ErrCodeRoleNotAllowed ErrorCode = "ROLE_NOT_ALLOWED"
)
// OIDCError represents a structured error with context
type OIDCError struct {
Code ErrorCode `json:"code"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
HTTPStatus int `json:"http_status"`
Internal error `json:"-"` // Internal error, not exposed
}
// Error implements the error interface
func (e *OIDCError) Error() string {
if e.Details != "" {
return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details)
}
return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// Unwrap returns the internal error for error wrapping
func (e *OIDCError) Unwrap() error {
return e.Internal
}
// IsRetryable indicates if the error is temporary and can be retried
func (e *OIDCError) IsRetryable() bool {
return e.Code == ErrCodeNetworkTimeout ||
e.Code == ErrCodeServiceUnavailable ||
e.Code == ErrCodeProviderUnreachable
}
// IsAuthenticationError indicates if this is an authentication-related error
func (e *OIDCError) IsAuthenticationError() bool {
return e.Code == ErrCodeAuthenticationFailed ||
e.Code == ErrCodeTokenExpired ||
e.Code == ErrCodeTokenInvalid ||
e.Code == ErrCodeSessionExpired ||
e.Code == ErrCodeCSRFMismatch ||
e.Code == ErrCodeNonceMismatch
}
// IsAuthorizationError indicates if this is an authorization-related error
func (e *OIDCError) IsAuthorizationError() bool {
return e.Code == ErrCodeDomainNotAllowed ||
e.Code == ErrCodeUserNotAllowed ||
e.Code == ErrCodeRoleNotAllowed
}
// ToJSON converts the error to a JSON response
func (e *OIDCError) ToJSON() map[string]any {
result := map[string]any{
"error": map[string]any{
"code": string(e.Code),
"message": e.Message,
},
}
if e.Details != "" {
errorMap, _ := result["error"].(map[string]any) // Safe to ignore: type assertion from known type
errorMap["details"] = e.Details
}
return result
}
// Error constructors for common scenarios
// NewAuthenticationError creates an authentication-related error
func NewAuthenticationError(code ErrorCode, message string, internal error) *OIDCError {
status := http.StatusUnauthorized
if code == ErrCodeSessionExpired {
status = http.StatusForbidden
}
return &OIDCError{
Code: code,
Message: message,
HTTPStatus: status,
Internal: internal,
}
}
// NewAuthorizationError creates an authorization-related error
func NewAuthorizationError(code ErrorCode, message string, details string) *OIDCError {
return &OIDCError{
Code: code,
Message: message,
Details: details,
HTTPStatus: http.StatusForbidden,
}
}
// NewConfigurationError creates a configuration-related error
func NewConfigurationError(code ErrorCode, message string, internal error) *OIDCError {
return &OIDCError{
Code: code,
Message: message,
HTTPStatus: http.StatusInternalServerError,
Internal: internal,
}
}
// NewNetworkError creates a network-related error
func NewNetworkError(code ErrorCode, message string, internal error) *OIDCError {
status := http.StatusServiceUnavailable
if code == ErrCodeRateLimited {
status = http.StatusTooManyRequests
}
return &OIDCError{
Code: code,
Message: message,
HTTPStatus: status,
Internal: internal,
}
}
// NewValidationError creates a validation-related error
func NewValidationError(code ErrorCode, message string, details string) *OIDCError {
return &OIDCError{
Code: code,
Message: message,
Details: details,
HTTPStatus: http.StatusBadRequest,
}
}
// Convenience functions for common error patterns
// WrapAuthenticationError wraps an existing error as an authentication error
func WrapAuthenticationError(err error, message string) *OIDCError {
return NewAuthenticationError(ErrCodeAuthenticationFailed, message, err)
}
// WrapTokenError wraps a token-related error
func WrapTokenError(err error, tokenType string) *OIDCError {
message := fmt.Sprintf("Token validation failed: %s", tokenType)
return NewAuthenticationError(ErrCodeTokenInvalid, message, err)
}
// WrapProviderError wraps a provider communication error
func WrapProviderError(err error, providerURL string) *OIDCError {
message := fmt.Sprintf("Provider communication failed: %s", providerURL)
return NewNetworkError(ErrCodeProviderUnreachable, message, err)
}
// IsOIDCError checks if an error is an OIDCError
func IsOIDCError(err error) (*OIDCError, bool) {
oidcErr, ok := err.(*OIDCError)
return oidcErr, ok
}
// GetHTTPStatus extracts HTTP status from error, defaulting to 500
func GetHTTPStatus(err error) int {
if oidcErr, ok := IsOIDCError(err); ok {
return oidcErr.HTTPStatus
}
return http.StatusInternalServerError
}
// FormatUserMessage creates a user-friendly error message
func FormatUserMessage(err error) string {
if oidcErr, ok := IsOIDCError(err); ok {
switch oidcErr.Code {
case ErrCodeDomainNotAllowed:
return "Your email domain is not authorized for this application"
case ErrCodeUserNotAllowed:
return "Your account is not authorized for this application"
case ErrCodeRoleNotAllowed:
return "You do not have the required permissions for this application"
case ErrCodeSessionExpired:
return "Your session has expired. Please log in again"
case ErrCodeTokenExpired:
return "Your authentication has expired. Please log in again"
case ErrCodeProviderUnreachable:
return "Authentication service is temporarily unavailable. Please try again later"
case ErrCodeRateLimited:
return "Too many requests. Please wait a moment and try again"
default:
return "Authentication failed. Please try again"
}
}
return "An unexpected error occurred. Please try again"
}
-529
View File
@@ -1,529 +0,0 @@
package errors
import (
"errors"
"net/http"
"reflect"
"testing"
)
func TestOIDCError_Error(t *testing.T) {
tests := []struct {
name string
oidcErr *OIDCError
expected string
}{
{
name: "Error with details",
oidcErr: &OIDCError{
Code: ErrCodeTokenInvalid,
Message: "Token validation failed",
Details: "JWT signature invalid",
},
expected: "TOKEN_INVALID: Token validation failed (JWT signature invalid)",
},
{
name: "Error without details",
oidcErr: &OIDCError{
Code: ErrCodeAuthenticationFailed,
Message: "Authentication failed",
},
expected: "AUTH_FAILED: Authentication failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.oidcErr.Error()
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestOIDCError_Unwrap(t *testing.T) {
internalErr := errors.New("internal error")
oidcErr := &OIDCError{
Code: ErrCodeTokenInvalid,
Message: "Token validation failed",
Internal: internalErr,
}
unwrapped := oidcErr.Unwrap()
if unwrapped != internalErr {
t.Errorf("Expected internal error, got %v", unwrapped)
}
// Test with nil internal error
oidcErrNoInternal := &OIDCError{
Code: ErrCodeTokenInvalid,
Message: "Token validation failed",
}
unwrappedNil := oidcErrNoInternal.Unwrap()
if unwrappedNil != nil {
t.Errorf("Expected nil, got %v", unwrappedNil)
}
}
func TestOIDCError_IsRetryable(t *testing.T) {
tests := []struct {
name string
code ErrorCode
expected bool
}{
{"Network timeout", ErrCodeNetworkTimeout, true},
{"Service unavailable", ErrCodeServiceUnavailable, true},
{"Provider unreachable", ErrCodeProviderUnreachable, true},
{"Authentication failed", ErrCodeAuthenticationFailed, false},
{"Token invalid", ErrCodeTokenInvalid, false},
{"Rate limited", ErrCodeRateLimited, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oidcErr := &OIDCError{Code: tt.code}
result := oidcErr.IsRetryable()
if result != tt.expected {
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
}
})
}
}
func TestOIDCError_IsAuthenticationError(t *testing.T) {
tests := []struct {
name string
code ErrorCode
expected bool
}{
{"Authentication failed", ErrCodeAuthenticationFailed, true},
{"Token expired", ErrCodeTokenExpired, true},
{"Token invalid", ErrCodeTokenInvalid, true},
{"Session expired", ErrCodeSessionExpired, true},
{"CSRF mismatch", ErrCodeCSRFMismatch, true},
{"Nonce mismatch", ErrCodeNonceMismatch, true},
{"Config invalid", ErrCodeConfigInvalid, false},
{"Domain not allowed", ErrCodeDomainNotAllowed, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oidcErr := &OIDCError{Code: tt.code}
result := oidcErr.IsAuthenticationError()
if result != tt.expected {
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
}
})
}
}
func TestOIDCError_IsAuthorizationError(t *testing.T) {
tests := []struct {
name string
code ErrorCode
expected bool
}{
{"Domain not allowed", ErrCodeDomainNotAllowed, true},
{"User not allowed", ErrCodeUserNotAllowed, true},
{"Role not allowed", ErrCodeRoleNotAllowed, true},
{"Authentication failed", ErrCodeAuthenticationFailed, false},
{"Token expired", ErrCodeTokenExpired, false},
{"Config invalid", ErrCodeConfigInvalid, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oidcErr := &OIDCError{Code: tt.code}
result := oidcErr.IsAuthorizationError()
if result != tt.expected {
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
}
})
}
}
func TestOIDCError_ToJSON(t *testing.T) {
tests := []struct {
name string
oidcErr *OIDCError
expected map[string]any
}{
{
name: "Error with details",
oidcErr: &OIDCError{
Code: ErrCodeTokenInvalid,
Message: "Token validation failed",
Details: "JWT signature invalid",
},
expected: map[string]any{
"error": map[string]any{
"code": "TOKEN_INVALID",
"message": "Token validation failed",
"details": "JWT signature invalid",
},
},
},
{
name: "Error without details",
oidcErr: &OIDCError{
Code: ErrCodeAuthenticationFailed,
Message: "Authentication failed",
},
expected: map[string]any{
"error": map[string]any{
"code": "AUTH_FAILED",
"message": "Authentication failed",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.oidcErr.ToJSON()
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("Expected %+v, got %+v", tt.expected, result)
}
})
}
}
func TestNewAuthenticationError(t *testing.T) {
internalErr := errors.New("internal error")
tests := []struct {
name string
code ErrorCode
message string
internal error
expectedHTTP int
}{
{
name: "Regular auth error",
code: ErrCodeAuthenticationFailed,
message: "Auth failed",
internal: internalErr,
expectedHTTP: http.StatusUnauthorized,
},
{
name: "Session expired error",
code: ErrCodeSessionExpired,
message: "Session expired",
internal: internalErr,
expectedHTTP: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := NewAuthenticationError(tt.code, tt.message, tt.internal)
if err.Code != tt.code {
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
}
if err.Message != tt.message {
t.Errorf("Expected message '%s', got '%s'", tt.message, err.Message)
}
if err.Internal != tt.internal {
t.Errorf("Expected internal error %v, got %v", tt.internal, err.Internal)
}
if err.HTTPStatus != tt.expectedHTTP {
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
}
})
}
}
func TestNewAuthorizationError(t *testing.T) {
err := NewAuthorizationError(ErrCodeDomainNotAllowed, "Domain not allowed", "example.com not in whitelist")
if err.Code != ErrCodeDomainNotAllowed {
t.Errorf("Expected code %s, got %s", ErrCodeDomainNotAllowed, err.Code)
}
if err.Message != "Domain not allowed" {
t.Errorf("Expected message 'Domain not allowed', got '%s'", err.Message)
}
if err.Details != "example.com not in whitelist" {
t.Errorf("Expected details 'example.com not in whitelist', got '%s'", err.Details)
}
if err.HTTPStatus != http.StatusForbidden {
t.Errorf("Expected HTTP status %d, got %d", http.StatusForbidden, err.HTTPStatus)
}
}
func TestNewConfigurationError(t *testing.T) {
internalErr := errors.New("config parse error")
err := NewConfigurationError(ErrCodeConfigInvalid, "Invalid config", internalErr)
if err.Code != ErrCodeConfigInvalid {
t.Errorf("Expected code %s, got %s", ErrCodeConfigInvalid, err.Code)
}
if err.HTTPStatus != http.StatusInternalServerError {
t.Errorf("Expected HTTP status %d, got %d", http.StatusInternalServerError, err.HTTPStatus)
}
if err.Internal != internalErr {
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
}
}
func TestNewNetworkError(t *testing.T) {
internalErr := errors.New("network error")
tests := []struct {
name string
code ErrorCode
expectedHTTP int
}{
{
name: "Rate limited",
code: ErrCodeRateLimited,
expectedHTTP: http.StatusTooManyRequests,
},
{
name: "Service unavailable",
code: ErrCodeServiceUnavailable,
expectedHTTP: http.StatusServiceUnavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := NewNetworkError(tt.code, "Network error", internalErr)
if err.Code != tt.code {
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
}
if err.HTTPStatus != tt.expectedHTTP {
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
}
})
}
}
func TestNewValidationError(t *testing.T) {
err := NewValidationError(ErrCodeValidationFailed, "Validation failed", "field 'email' is required")
if err.Code != ErrCodeValidationFailed {
t.Errorf("Expected code %s, got %s", ErrCodeValidationFailed, err.Code)
}
if err.HTTPStatus != http.StatusBadRequest {
t.Errorf("Expected HTTP status %d, got %d", http.StatusBadRequest, err.HTTPStatus)
}
if err.Details != "field 'email' is required" {
t.Errorf("Expected details 'field 'email' is required', got '%s'", err.Details)
}
}
func TestWrapAuthenticationError(t *testing.T) {
internalErr := errors.New("original error")
err := WrapAuthenticationError(internalErr, "Custom auth message")
if err.Code != ErrCodeAuthenticationFailed {
t.Errorf("Expected code %s, got %s", ErrCodeAuthenticationFailed, err.Code)
}
if err.Message != "Custom auth message" {
t.Errorf("Expected message 'Custom auth message', got '%s'", err.Message)
}
if err.Internal != internalErr {
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
}
}
func TestWrapTokenError(t *testing.T) {
internalErr := errors.New("token error")
err := WrapTokenError(internalErr, "ID token")
if err.Code != ErrCodeTokenInvalid {
t.Errorf("Expected code %s, got %s", ErrCodeTokenInvalid, err.Code)
}
if err.Message != "Token validation failed: ID token" {
t.Errorf("Expected message 'Token validation failed: ID token', got '%s'", err.Message)
}
if err.Internal != internalErr {
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
}
}
func TestWrapProviderError(t *testing.T) {
internalErr := errors.New("provider error")
err := WrapProviderError(internalErr, "https://provider.example.com")
if err.Code != ErrCodeProviderUnreachable {
t.Errorf("Expected code %s, got %s", ErrCodeProviderUnreachable, err.Code)
}
if err.Message != "Provider communication failed: https://provider.example.com" {
t.Errorf("Expected specific message, got '%s'", err.Message)
}
if err.Internal != internalErr {
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
}
}
func TestIsOIDCError(t *testing.T) {
// Test with OIDCError
oidcErr := &OIDCError{Code: ErrCodeTokenInvalid, Message: "test"}
result, ok := IsOIDCError(oidcErr)
if !ok {
t.Error("Expected IsOIDCError to return true for OIDCError")
}
if result != oidcErr {
t.Error("Expected to get the same OIDCError back")
}
// Test with regular error
regularErr := errors.New("regular error")
result, ok = IsOIDCError(regularErr)
if ok {
t.Error("Expected IsOIDCError to return false for regular error")
}
if result != nil {
t.Error("Expected nil result for regular error")
}
}
func TestGetHTTPStatus(t *testing.T) {
// Test with OIDCError
oidcErr := &OIDCError{
Code: ErrCodeTokenInvalid,
HTTPStatus: http.StatusUnauthorized,
}
status := GetHTTPStatus(oidcErr)
if status != http.StatusUnauthorized {
t.Errorf("Expected %d, got %d", http.StatusUnauthorized, status)
}
// Test with regular error
regularErr := errors.New("regular error")
status = GetHTTPStatus(regularErr)
if status != http.StatusInternalServerError {
t.Errorf("Expected %d, got %d", http.StatusInternalServerError, status)
}
}
func TestFormatUserMessage(t *testing.T) {
tests := []struct {
name string
err error
expected string
}{
{
name: "Domain not allowed",
err: &OIDCError{Code: ErrCodeDomainNotAllowed},
expected: "Your email domain is not authorized for this application",
},
{
name: "User not allowed",
err: &OIDCError{Code: ErrCodeUserNotAllowed},
expected: "Your account is not authorized for this application",
},
{
name: "Role not allowed",
err: &OIDCError{Code: ErrCodeRoleNotAllowed},
expected: "You do not have the required permissions for this application",
},
{
name: "Session expired",
err: &OIDCError{Code: ErrCodeSessionExpired},
expected: "Your session has expired. Please log in again",
},
{
name: "Token expired",
err: &OIDCError{Code: ErrCodeTokenExpired},
expected: "Your authentication has expired. Please log in again",
},
{
name: "Provider unreachable",
err: &OIDCError{Code: ErrCodeProviderUnreachable},
expected: "Authentication service is temporarily unavailable. Please try again later",
},
{
name: "Rate limited",
err: &OIDCError{Code: ErrCodeRateLimited},
expected: "Too many requests. Please wait a moment and try again",
},
{
name: "Unknown OIDC error",
err: &OIDCError{Code: ErrCodeConfigInvalid},
expected: "Authentication failed. Please try again",
},
{
name: "Regular error",
err: errors.New("regular error"),
expected: "An unexpected error occurred. Please try again",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatUserMessage(tt.err)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestErrorCodes(t *testing.T) {
// Test that all error codes are defined correctly
codes := []ErrorCode{
ErrCodeAuthenticationFailed,
ErrCodeTokenExpired,
ErrCodeTokenInvalid,
ErrCodeSessionExpired,
ErrCodeCSRFMismatch,
ErrCodeNonceMismatch,
ErrCodeConfigInvalid,
ErrCodeProviderUnreachable,
ErrCodeMetadataFailed,
ErrCodeNetworkTimeout,
ErrCodeRateLimited,
ErrCodeServiceUnavailable,
ErrCodeValidationFailed,
ErrCodeDomainNotAllowed,
ErrCodeUserNotAllowed,
ErrCodeRoleNotAllowed,
}
for _, code := range codes {
if string(code) == "" {
t.Errorf("Error code %v is empty", code)
}
}
}
func TestErrorConstructorCompleteness(t *testing.T) {
// Test each constructor function to ensure they set all required fields
internalErr := errors.New("test error")
// Test NewAuthenticationError
authErr := NewAuthenticationError(ErrCodeAuthenticationFailed, "auth message", internalErr)
if authErr.Code == "" || authErr.Message == "" || authErr.HTTPStatus == 0 {
t.Error("NewAuthenticationError did not set all required fields")
}
// Test NewAuthorizationError
authzErr := NewAuthorizationError(ErrCodeDomainNotAllowed, "authz message", "details")
if authzErr.Code == "" || authzErr.Message == "" || authzErr.HTTPStatus == 0 {
t.Error("NewAuthorizationError did not set all required fields")
}
// Test NewConfigurationError
configErr := NewConfigurationError(ErrCodeConfigInvalid, "config message", internalErr)
if configErr.Code == "" || configErr.Message == "" || configErr.HTTPStatus == 0 {
t.Error("NewConfigurationError did not set all required fields")
}
// Test NewNetworkError
netErr := NewNetworkError(ErrCodeNetworkTimeout, "network message", internalErr)
if netErr.Code == "" || netErr.Message == "" || netErr.HTTPStatus == 0 {
t.Error("NewNetworkError did not set all required fields")
}
// Test NewValidationError
validErr := NewValidationError(ErrCodeValidationFailed, "validation message", "details")
if validErr.Code == "" || validErr.Message == "" || validErr.HTTPStatus == 0 {
t.Error("NewValidationError did not set all required fields")
}
}
-224
View File
@@ -1,224 +0,0 @@
// Package handlers provides authentication flow management
package handlers
import (
"net/http"
"time"
)
// AuthFlowHandler manages the complete OIDC authentication flow
type AuthFlowHandler struct {
sessionHandler *SessionHandler
tokenHandler TokenHandler
logger Logger
excludedURLs map[string]struct{}
initComplete chan struct{}
issuerURL string
}
// TokenHandler interface for token operations
type TokenHandler interface {
VerifyToken(token string) error
RefreshToken(refreshToken string) (*TokenResponse, error)
}
// TokenResponse represents token exchange response
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
// AuthFlowResult represents the result of authentication flow processing
type AuthFlowResult struct {
Authenticated bool
RequiresAuth bool
RequiresRefresh bool
Error error
RedirectURL string
StatusCode int
}
// NewAuthFlowHandler creates a new authentication flow handler
func NewAuthFlowHandler(sessionHandler *SessionHandler, tokenHandler TokenHandler, logger Logger, excludedURLs map[string]struct{}, initComplete chan struct{}, issuerURL string) *AuthFlowHandler {
return &AuthFlowHandler{
sessionHandler: sessionHandler,
tokenHandler: tokenHandler,
logger: logger,
excludedURLs: excludedURLs,
initComplete: initComplete,
issuerURL: issuerURL,
}
}
// ProcessRequest handles the main authentication flow
func (h *AuthFlowHandler) ProcessRequest(rw http.ResponseWriter, req *http.Request) AuthFlowResult {
// Check if URL should be excluded
if h.shouldExcludeURL(req.URL.Path) {
h.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
return AuthFlowResult{Authenticated: true}
}
// Check for streaming requests
if h.isStreamingRequest(req) {
h.logger.Debugf("Streaming request detected, bypassing OIDC")
return AuthFlowResult{Authenticated: true}
}
// Wait for initialization
if !h.waitForInitialization(req) {
return AuthFlowResult{
Error: ErrInitializationTimeout,
StatusCode: http.StatusServiceUnavailable,
}
}
// Get and validate session
session, err := h.sessionHandler.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Error getting session: %v", err)
return AuthFlowResult{
RequiresAuth: true,
Error: err,
}
}
defer session.ReturnToPoolSafely()
// Clean up old cookies
h.sessionHandler.sessionManager.CleanupOldCookies(rw, req)
// Validate session
validationResult := h.sessionHandler.ValidateSession(session)
if !validationResult.Valid {
if validationResult.NeedsAuth {
return AuthFlowResult{RequiresAuth: true}
}
return AuthFlowResult{
Error: ErrSessionInvalid,
StatusCode: http.StatusUnauthorized,
}
}
// Check token validity and refresh if needed
return h.validateAndRefreshTokens(session, req, rw)
}
// shouldExcludeURL checks if a URL should bypass authentication
func (h *AuthFlowHandler) shouldExcludeURL(path string) bool {
for excludedURL := range h.excludedURLs {
if len(path) >= len(excludedURL) && path[:len(excludedURL)] == excludedURL {
return true
}
}
return false
}
// isStreamingRequest checks if request is a streaming request that should bypass auth
func (h *AuthFlowHandler) isStreamingRequest(req *http.Request) bool {
acceptHeader := req.Header.Get("Accept")
return acceptHeader == "text/event-stream"
}
// waitForInitialization waits for OIDC provider initialization
func (h *AuthFlowHandler) waitForInitialization(req *http.Request) bool {
select {
case <-h.initComplete:
if h.issuerURL == "" {
h.logger.Error("OIDC provider metadata initialization failed")
return false
}
return true
case <-req.Context().Done():
h.logger.Debug("Request canceled while waiting for OIDC initialization")
return false
case <-time.After(30 * time.Second):
h.logger.Error("Timeout waiting for OIDC initialization")
return false
}
}
// validateAndRefreshTokens handles token validation and refresh logic
func (h *AuthFlowHandler) validateAndRefreshTokens(session Session, req *http.Request, rw http.ResponseWriter) AuthFlowResult {
// Check access token if present
if accessToken := session.GetAccessToken(); accessToken != "" {
if err := h.tokenHandler.VerifyToken(accessToken); err != nil {
h.logger.Errorf("Access token validation failed: %v", err)
// Try refresh if refresh token is available
if refreshToken := session.GetRefreshToken(); refreshToken != "" {
return h.attemptTokenRefresh(session, req, rw)
}
return AuthFlowResult{RequiresAuth: true}
}
}
// Check ID token
if idToken := session.GetIDToken(); idToken != "" {
if err := h.tokenHandler.VerifyToken(idToken); err != nil {
h.logger.Errorf("ID token validation failed: %v", err)
// Try refresh if refresh token is available
if refreshToken := session.GetRefreshToken(); refreshToken != "" {
return h.attemptTokenRefresh(session, req, rw)
}
return AuthFlowResult{RequiresAuth: true}
}
}
return AuthFlowResult{Authenticated: true}
}
// attemptTokenRefresh tries to refresh tokens
func (h *AuthFlowHandler) attemptTokenRefresh(session Session, req *http.Request, rw http.ResponseWriter) AuthFlowResult {
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
return AuthFlowResult{RequiresAuth: true}
}
// Check if this is an AJAX request
if h.sessionHandler.IsAjaxRequest(req) {
return AuthFlowResult{
Error: ErrSessionExpiredAjax,
StatusCode: http.StatusUnauthorized,
}
}
_, err := h.tokenHandler.RefreshToken(refreshToken)
if err != nil {
h.logger.Errorf("Token refresh failed: %v", err)
return AuthFlowResult{RequiresAuth: true}
}
// Update session with new tokens would be handled here
// Implementation depends on the actual session interface
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save refreshed session: %v", err)
return AuthFlowResult{
Error: err,
StatusCode: http.StatusInternalServerError,
}
}
return AuthFlowResult{Authenticated: true}
}
// Common errors
var (
ErrInitializationTimeout = &AuthFlowError{Code: "INIT_TIMEOUT", Message: "OIDC initialization timeout"}
ErrSessionInvalid = &AuthFlowError{Code: "SESSION_INVALID", Message: "Invalid session"}
ErrSessionExpiredAjax = &AuthFlowError{Code: "SESSION_EXPIRED_AJAX", Message: "Session expired for AJAX request"}
)
// AuthFlowError represents authentication flow errors
type AuthFlowError struct {
Code string
Message string
}
func (e *AuthFlowError) Error() string {
return e.Message
}
-588
View File
@@ -1,588 +0,0 @@
package handlers
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// Mock implementations that embed SessionHandler
type MockSessionHandlerWrapper struct {
*SessionHandler
}
func NewMockSessionHandlerWrapper() *MockSessionHandlerWrapper {
sessionManager := &MockSessionManager{}
logger := &MockLogger{}
sessionHandler := NewSessionHandler(
sessionManager,
logger,
"/logout",
"https://example.com/post-logout",
"https://provider.example.com/logout",
"test-client-id",
)
return &MockSessionHandlerWrapper{
SessionHandler: sessionHandler,
}
}
type MockSessionManager struct {
session Session
err error
}
func (m *MockSessionManager) GetSession(req *http.Request) (Session, error) {
return m.session, m.err
}
func (m *MockSessionManager) CleanupOldCookies(rw http.ResponseWriter, req *http.Request) {
// Mock implementation
}
type MockSession struct {
authenticated bool
email string
idToken string
accessToken string
refreshToken string
saveError error
clearError error
}
func (m *MockSession) GetAuthenticated() bool { return m.authenticated }
func (m *MockSession) SetAuthenticated(auth bool) error { m.authenticated = auth; return nil }
func (m *MockSession) GetEmail() string { return m.email }
func (m *MockSession) SetEmail(email string) { m.email = email }
func (m *MockSession) GetIDToken() string { return m.idToken }
func (m *MockSession) GetAccessToken() string { return m.accessToken }
func (m *MockSession) GetRefreshToken() string { return m.refreshToken }
func (m *MockSession) SetRefreshToken(token string) { m.refreshToken = token }
func (m *MockSession) Clear(req *http.Request, rw http.ResponseWriter) error { return m.clearError }
func (m *MockSession) Save(req *http.Request, rw http.ResponseWriter) error { return m.saveError }
func (m *MockSession) ReturnToPoolSafely() {}
type MockTokenHandler struct {
verifyError error
refreshError error
tokenResponse *TokenResponse
}
func (m *MockTokenHandler) VerifyToken(token string) error {
return m.verifyError
}
func (m *MockTokenHandler) RefreshToken(refreshToken string) (*TokenResponse, error) {
return m.tokenResponse, m.refreshError
}
type MockLogger struct {
debugMessages []string
errorMessages []string
}
func (m *MockLogger) Debug(msg string) {
m.debugMessages = append(m.debugMessages, msg)
}
func (m *MockLogger) Debugf(format string, args ...interface{}) {
m.debugMessages = append(m.debugMessages, format)
}
func (m *MockLogger) Info(msg string) {}
func (m *MockLogger) Infof(format string, args ...interface{}) {}
func (m *MockLogger) Error(msg string) {
m.errorMessages = append(m.errorMessages, msg)
}
func (m *MockLogger) Errorf(format string, args ...interface{}) {
m.errorMessages = append(m.errorMessages, format)
}
func TestNewAuthFlowHandler(t *testing.T) {
sessionHandler := NewMockSessionHandlerWrapper()
tokenHandler := &MockTokenHandler{}
logger := &MockLogger{}
excludedURLs := map[string]struct{}{"/health": {}}
initComplete := make(chan struct{})
issuerURL := "https://issuer.example.com"
handler := NewAuthFlowHandler(sessionHandler.SessionHandler, tokenHandler, logger, excludedURLs, initComplete, issuerURL)
if handler == nil {
t.Fatal("NewAuthFlowHandler returned nil")
}
if handler.sessionHandler == nil {
t.Error("SessionHandler not set correctly")
}
if handler.tokenHandler != tokenHandler {
t.Error("TokenHandler not set correctly")
}
if handler.logger != logger {
t.Error("Logger not set correctly")
}
if handler.issuerURL != issuerURL {
t.Error("IssuerURL not set correctly")
}
}
func TestAuthFlowHandler_shouldExcludeURL(t *testing.T) {
excludedURLs := map[string]struct{}{
"/health": {},
"/metrics": {},
"/api/public": {},
}
handler := &AuthFlowHandler{excludedURLs: excludedURLs}
tests := []struct {
path string
expected bool
}{
{"/health", true},
{"/health/check", true},
{"/metrics", true},
{"/metrics/prometheus", true},
{"/api/public", true},
{"/api/public/endpoint", true},
{"/api/private", false},
{"/login", false},
{"/dashboard", false},
}
for _, test := range tests {
result := handler.shouldExcludeURL(test.path)
if result != test.expected {
t.Errorf("For path '%s': expected %v, got %v", test.path, test.expected, result)
}
}
}
func TestAuthFlowHandler_isStreamingRequest(t *testing.T) {
handler := &AuthFlowHandler{}
tests := []struct {
name string
accept string
expected bool
}{
{
name: "SSE request",
accept: "text/event-stream",
expected: true,
},
{
name: "Regular HTML request",
accept: "text/html,application/xhtml+xml",
expected: false,
},
{
name: "JSON request",
accept: "application/json",
expected: false,
},
{
name: "Empty accept header",
accept: "",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Accept", test.accept)
result := handler.isStreamingRequest(req)
if result != test.expected {
t.Errorf("Expected %v, got %v", test.expected, result)
}
})
}
}
func TestAuthFlowHandler_waitForInitialization(t *testing.T) {
tests := []struct {
name string
setupHandler func() (*AuthFlowHandler, context.CancelFunc)
expectedResult bool
}{
{
name: "Initialization complete successfully",
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
initComplete := make(chan struct{})
close(initComplete) // Already complete
handler := &AuthFlowHandler{
initComplete: initComplete,
issuerURL: "https://issuer.example.com",
}
return handler, nil
},
expectedResult: true,
},
{
name: "Initialization complete but no issuer URL",
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
initComplete := make(chan struct{})
close(initComplete)
handler := &AuthFlowHandler{
initComplete: initComplete,
issuerURL: "",
logger: &MockLogger{},
}
return handler, nil
},
expectedResult: false,
},
{
name: "Request canceled",
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
initComplete := make(chan struct{})
handler := &AuthFlowHandler{
initComplete: initComplete,
logger: &MockLogger{},
}
_, cancel := context.WithCancel(context.Background())
return handler, cancel
},
expectedResult: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler, cancelFunc := test.setupHandler()
req := httptest.NewRequest("GET", "/", nil)
if cancelFunc != nil {
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)
cancel() // Cancel immediately
}
result := handler.waitForInitialization(req)
if result != test.expectedResult {
t.Errorf("Expected %v, got %v", test.expectedResult, result)
}
})
}
}
func TestAuthFlowHandler_ProcessRequest(t *testing.T) {
tests := []struct {
name string
setupRequest func() *http.Request
setupHandler func() *AuthFlowHandler
expectedResult AuthFlowResult
}{
{
name: "Excluded URL bypasses authentication",
setupRequest: func() *http.Request {
return httptest.NewRequest("GET", "/health", nil)
},
setupHandler: func() *AuthFlowHandler {
return &AuthFlowHandler{
excludedURLs: map[string]struct{}{"/health": {}},
logger: &MockLogger{},
}
},
expectedResult: AuthFlowResult{Authenticated: true},
},
{
name: "Streaming request bypasses authentication",
setupRequest: func() *http.Request {
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
return req
},
setupHandler: func() *AuthFlowHandler {
return &AuthFlowHandler{
excludedURLs: map[string]struct{}{},
logger: &MockLogger{},
}
},
expectedResult: AuthFlowResult{Authenticated: true},
},
{
name: "Initialization timeout",
setupRequest: func() *http.Request {
return httptest.NewRequest("GET", "/dashboard", nil)
},
setupHandler: func() *AuthFlowHandler {
return &AuthFlowHandler{
excludedURLs: map[string]struct{}{},
initComplete: make(chan struct{}), // Never closes
logger: &MockLogger{},
}
},
expectedResult: AuthFlowResult{
Error: ErrInitializationTimeout,
StatusCode: http.StatusServiceUnavailable,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := test.setupRequest()
handler := test.setupHandler()
rw := httptest.NewRecorder()
// For timeout test, use context with timeout
if test.name == "Initialization timeout" {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
req = req.WithContext(ctx)
}
result := handler.ProcessRequest(rw, req)
if result.Authenticated != test.expectedResult.Authenticated {
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
}
if result.StatusCode != test.expectedResult.StatusCode {
t.Errorf("Expected StatusCode %d, got %d", test.expectedResult.StatusCode, result.StatusCode)
}
if test.expectedResult.Error != nil && result.Error == nil {
t.Error("Expected error but got nil")
}
})
}
}
func TestAuthFlowHandler_validateAndRefreshTokens(t *testing.T) {
tests := []struct {
name string
session *MockSession
tokenHandler *MockTokenHandler
expectedResult AuthFlowResult
}{
{
name: "Valid access token",
session: &MockSession{
authenticated: true,
accessToken: "valid-access-token",
},
tokenHandler: &MockTokenHandler{
verifyError: nil,
},
expectedResult: AuthFlowResult{Authenticated: true},
},
{
name: "Invalid access token, successful refresh",
session: &MockSession{
authenticated: true,
accessToken: "invalid-access-token",
refreshToken: "valid-refresh-token",
},
tokenHandler: &MockTokenHandler{
verifyError: errors.New("token expired"),
refreshError: nil,
tokenResponse: &TokenResponse{
IDToken: "new-id-token",
AccessToken: "new-access-token",
},
},
expectedResult: AuthFlowResult{Authenticated: true},
},
{
name: "Invalid access token, no refresh token",
session: &MockSession{
authenticated: true,
accessToken: "invalid-access-token",
refreshToken: "",
},
tokenHandler: &MockTokenHandler{
verifyError: errors.New("token expired"),
},
expectedResult: AuthFlowResult{RequiresAuth: true},
},
{
name: "Valid ID token only",
session: &MockSession{
authenticated: true,
idToken: "valid-id-token",
},
tokenHandler: &MockTokenHandler{
verifyError: nil,
},
expectedResult: AuthFlowResult{Authenticated: true},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &AuthFlowHandler{
tokenHandler: test.tokenHandler,
logger: &MockLogger{},
}
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
result := handler.validateAndRefreshTokens(test.session, req, rw)
if result.Authenticated != test.expectedResult.Authenticated {
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
}
if result.RequiresAuth != test.expectedResult.RequiresAuth {
t.Errorf("Expected RequiresAuth %v, got %v", test.expectedResult.RequiresAuth, result.RequiresAuth)
}
})
}
}
func TestAuthFlowHandler_attemptTokenRefresh(t *testing.T) {
tests := []struct {
name string
session *MockSession
tokenHandler *MockTokenHandler
isAjax bool
expectedResult AuthFlowResult
}{
{
name: "No refresh token",
session: &MockSession{
refreshToken: "",
},
tokenHandler: &MockTokenHandler{},
expectedResult: AuthFlowResult{RequiresAuth: true},
},
{
name: "AJAX request with expired session",
session: &MockSession{
refreshToken: "refresh-token",
},
tokenHandler: &MockTokenHandler{},
isAjax: true,
expectedResult: AuthFlowResult{
Error: ErrSessionExpiredAjax,
StatusCode: http.StatusUnauthorized,
},
},
{
name: "Successful token refresh",
session: &MockSession{
refreshToken: "valid-refresh-token",
},
tokenHandler: &MockTokenHandler{
refreshError: nil,
tokenResponse: &TokenResponse{
IDToken: "new-id-token",
AccessToken: "new-access-token",
},
},
expectedResult: AuthFlowResult{Authenticated: true},
},
{
name: "Failed token refresh",
session: &MockSession{
refreshToken: "invalid-refresh-token",
},
tokenHandler: &MockTokenHandler{
refreshError: errors.New("refresh failed"),
},
expectedResult: AuthFlowResult{RequiresAuth: true},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sessionHandlerWrapper := NewMockSessionHandlerWrapper()
handler := &AuthFlowHandler{
sessionHandler: sessionHandlerWrapper.SessionHandler,
tokenHandler: test.tokenHandler,
logger: &MockLogger{},
}
req := httptest.NewRequest("GET", "/", nil)
if test.isAjax {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}
rw := httptest.NewRecorder()
result := handler.attemptTokenRefresh(test.session, req, rw)
if result.Authenticated != test.expectedResult.Authenticated {
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
}
if result.RequiresAuth != test.expectedResult.RequiresAuth {
t.Errorf("Expected RequiresAuth %v, got %v", test.expectedResult.RequiresAuth, result.RequiresAuth)
}
if result.StatusCode != test.expectedResult.StatusCode {
t.Errorf("Expected StatusCode %d, got %d", test.expectedResult.StatusCode, result.StatusCode)
}
})
}
}
func TestAuthFlowError_Error(t *testing.T) {
err := &AuthFlowError{
Code: "TEST_ERROR",
Message: "This is a test error",
}
expected := "This is a test error"
result := err.Error()
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
}
func TestAuthFlowResult(t *testing.T) {
// Test AuthFlowResult struct
result := AuthFlowResult{
Authenticated: true,
RequiresAuth: false,
RequiresRefresh: false,
Error: nil,
RedirectURL: "https://example.com",
StatusCode: 200,
}
if !result.Authenticated {
t.Error("Expected Authenticated to be true")
}
if result.RequiresAuth {
t.Error("Expected RequiresAuth to be false")
}
if result.StatusCode != 200 {
t.Errorf("Expected StatusCode 200, got %d", result.StatusCode)
}
}
func TestTokenResponse(t *testing.T) {
response := &TokenResponse{
IDToken: "id-token-value",
AccessToken: "access-token-value",
RefreshToken: "refresh-token-value",
ExpiresIn: 3600,
}
if response.IDToken != "id-token-value" {
t.Errorf("Expected IDToken 'id-token-value', got '%s'", response.IDToken)
}
if response.ExpiresIn != 3600 {
t.Errorf("Expected ExpiresIn 3600, got %d", response.ExpiresIn)
}
}
-247
View File
@@ -1,247 +0,0 @@
// Package handlers provides HTTP request handlers for OIDC operations
package handlers
import (
"fmt"
"net/http"
"strings"
)
// SessionHandler manages session-related HTTP operations
type SessionHandler struct {
sessionManager SessionManager
logger Logger
logoutURLPath string
postLogoutRedirectURI string
endSessionURL string
clientID string
}
// SessionManager interface for session operations
type SessionManager interface {
GetSession(req *http.Request) (Session, error)
CleanupOldCookies(rw http.ResponseWriter, req *http.Request)
}
// Session interface for session data
type Session interface {
GetAuthenticated() bool
SetAuthenticated(bool) error
GetEmail() string
SetEmail(string)
GetIDToken() string
GetAccessToken() string
GetRefreshToken() string
SetRefreshToken(string)
Clear(req *http.Request, rw http.ResponseWriter) error
Save(req *http.Request, rw http.ResponseWriter) error
ReturnToPoolSafely()
}
// Logger interface for logging operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// NewSessionHandler creates a new session handler
func NewSessionHandler(sessionManager SessionManager, logger Logger, logoutURLPath, postLogoutRedirectURI, endSessionURL, clientID string) *SessionHandler {
return &SessionHandler{
sessionManager: sessionManager,
logger: logger,
logoutURLPath: logoutURLPath,
postLogoutRedirectURI: postLogoutRedirectURI,
endSessionURL: endSessionURL,
clientID: clientID,
}
}
// HandleLogout processes logout requests
func (h *SessionHandler) HandleLogout(rw http.ResponseWriter, req *http.Request) {
h.logger.Debug("Processing logout request")
session, err := h.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Error getting session during logout: %v", err)
// Continue with logout even if session retrieval fails
}
var idToken string
if session != nil {
defer session.ReturnToPoolSafely()
idToken = session.GetIDToken()
// Clear the session
if err := session.Clear(req, rw); err != nil {
h.logger.Errorf("Error clearing session during logout: %v", err)
}
}
// Build logout URL
logoutURL := h.buildLogoutURL(idToken)
h.logger.Debugf("Redirecting to logout URL: %s", logoutURL)
http.Redirect(rw, req, logoutURL, http.StatusFound)
}
// buildLogoutURL constructs the provider logout URL
func (h *SessionHandler) buildLogoutURL(idToken string) string {
if h.endSessionURL == "" {
// If no end session URL, redirect to post-logout redirect URI
return h.postLogoutRedirectURI
}
logoutURL := h.endSessionURL
// Add query parameters
params := make([]string, 0, 3)
if idToken != "" {
params = append(params, fmt.Sprintf("id_token_hint=%s", idToken))
}
if h.postLogoutRedirectURI != "" {
params = append(params, fmt.Sprintf("post_logout_redirect_uri=%s", h.postLogoutRedirectURI))
}
if h.clientID != "" {
params = append(params, fmt.Sprintf("client_id=%s", h.clientID))
}
if len(params) > 0 {
separator := "?"
if strings.Contains(logoutURL, "?") {
separator = "&"
}
logoutURL += separator + strings.Join(params, "&")
}
return logoutURL
}
// ValidateSession checks if a session is valid and authenticated
func (h *SessionHandler) ValidateSession(session Session) SessionValidationResult {
if session == nil {
return SessionValidationResult{
Valid: false,
NeedsAuth: true,
ErrorMessage: "session is nil",
}
}
if !session.GetAuthenticated() {
return SessionValidationResult{
Valid: false,
NeedsAuth: true,
ErrorMessage: "session not authenticated",
}
}
email := session.GetEmail()
if email == "" {
return SessionValidationResult{
Valid: false,
NeedsAuth: true,
ErrorMessage: "no email in session",
}
}
return SessionValidationResult{
Valid: true,
NeedsAuth: false,
}
}
// SessionValidationResult represents the result of session validation
type SessionValidationResult struct {
Valid bool
NeedsAuth bool
ErrorMessage string
}
// CleanupExpiredSession clears an expired session
func (h *SessionHandler) CleanupExpiredSession(rw http.ResponseWriter, req *http.Request, session Session) error {
h.logger.Debug("Cleaning up expired session")
if session == nil {
return nil
}
// Clear all session data
if err := session.SetAuthenticated(false); err != nil {
h.logger.Errorf("Failed to set authenticated to false: %v", err)
}
session.SetEmail("")
session.SetRefreshToken("")
// Save the cleared session
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save cleared session: %v", err)
return err
}
return nil
}
// IsAjaxRequest determines if the request is an AJAX/XHR request
func (h *SessionHandler) IsAjaxRequest(req *http.Request) bool {
// Check X-Requested-With header (commonly used by jQuery and other libraries)
if req.Header.Get("X-Requested-With") == "XMLHttpRequest" {
return true
}
// Check Accept header for JSON preference
accept := req.Header.Get("Accept")
if strings.Contains(accept, "application/json") && !strings.Contains(accept, "text/html") {
return true
}
// Check for fetch API indication
if req.Header.Get("Sec-Fetch-Mode") == "cors" {
return true
}
return false
}
// SendErrorResponse sends an appropriate error response based on request type
func (h *SessionHandler) SendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, statusCode int) {
if h.IsAjaxRequest(req) {
// For AJAX requests, send JSON response
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(statusCode)
_, _ = fmt.Fprintf(rw, `{"error": "%s"}`, message) // Safe to ignore: writing error response
} else {
// For browser requests, send HTML response
rw.Header().Set("Content-Type", "text/html")
rw.WriteHeader(statusCode)
_, _ = fmt.Fprintf(rw, `<html><body><h1>Error %d</h1><p>%s</p></body></html>`, statusCode, message) // Safe to ignore: writing error response
}
}
// SetSecurityHeaders sets standard security headers
func (h *SessionHandler) SetSecurityHeaders(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Handle CORS for AJAX requests
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
}
-587
View File
@@ -1,587 +0,0 @@
package handlers
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestNewSessionHandler(t *testing.T) {
sessionManager := &MockSessionManager{}
logger := &MockLogger{}
logoutURLPath := "/logout"
postLogoutRedirectURI := "https://example.com/post-logout"
endSessionURL := "https://provider.example.com/logout"
clientID := "test-client-id"
handler := NewSessionHandler(
sessionManager,
logger,
logoutURLPath,
postLogoutRedirectURI,
endSessionURL,
clientID,
)
if handler == nil {
t.Fatal("NewSessionHandler returned nil")
}
if handler.sessionManager != sessionManager {
t.Error("SessionManager not set correctly")
}
if handler.logger != logger {
t.Error("Logger not set correctly")
}
if handler.logoutURLPath != logoutURLPath {
t.Error("LogoutURLPath not set correctly")
}
if handler.postLogoutRedirectURI != postLogoutRedirectURI {
t.Error("PostLogoutRedirectURI not set correctly")
}
if handler.endSessionURL != endSessionURL {
t.Error("EndSessionURL not set correctly")
}
if handler.clientID != clientID {
t.Error("ClientID not set correctly")
}
}
func TestSessionHandler_HandleLogout(t *testing.T) {
tests := []struct {
name string
setupSession func() *MockSession
setupManager func() *MockSessionManager
expectedCode int
expectedURL string
}{
{
name: "Successful logout with ID token",
setupSession: func() *MockSession {
return &MockSession{
authenticated: true,
idToken: "test-id-token",
}
},
setupManager: func() *MockSessionManager {
return &MockSessionManager{
session: &MockSession{
authenticated: true,
idToken: "test-id-token",
},
}
},
expectedCode: http.StatusFound,
expectedURL: "https://provider.example.com/logout?id_token_hint=test-id-token&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
{
name: "Logout without ID token",
setupSession: func() *MockSession {
return &MockSession{
authenticated: true,
idToken: "",
}
},
setupManager: func() *MockSessionManager {
return &MockSessionManager{
session: &MockSession{
authenticated: true,
idToken: "",
},
}
},
expectedCode: http.StatusFound,
expectedURL: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
{
name: "Session retrieval error",
setupSession: func() *MockSession { return nil },
setupManager: func() *MockSessionManager {
return &MockSessionManager{
err: fmt.Errorf("session error"),
}
},
expectedCode: http.StatusFound,
expectedURL: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &SessionHandler{
sessionManager: test.setupManager(),
logger: &MockLogger{},
logoutURLPath: "/logout",
postLogoutRedirectURI: "https://example.com/post-logout",
endSessionURL: "https://provider.example.com/logout",
clientID: "test-client-id",
}
req := httptest.NewRequest("POST", "/logout", nil)
rw := httptest.NewRecorder()
handler.HandleLogout(rw, req)
if rw.Code != test.expectedCode {
t.Errorf("Expected status code %d, got %d", test.expectedCode, rw.Code)
}
location := rw.Header().Get("Location")
if location != test.expectedURL {
t.Errorf("Expected location '%s', got '%s'", test.expectedURL, location)
}
})
}
}
func TestSessionHandler_buildLogoutURL(t *testing.T) {
tests := []struct {
name string
endSessionURL string
postLogoutRedirectURI string
clientID string
idToken string
expected string
}{
{
name: "Complete logout URL with all parameters",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://example.com/post-logout",
clientID: "test-client-id",
idToken: "test-id-token",
expected: "https://provider.example.com/logout?id_token_hint=test-id-token&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
{
name: "Logout URL without ID token",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://example.com/post-logout",
clientID: "test-client-id",
idToken: "",
expected: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
{
name: "No end session URL",
endSessionURL: "",
postLogoutRedirectURI: "https://example.com/post-logout",
clientID: "test-client-id",
idToken: "test-id-token",
expected: "https://example.com/post-logout",
},
{
name: "End session URL with existing query parameters",
endSessionURL: "https://provider.example.com/logout?foo=bar",
postLogoutRedirectURI: "https://example.com/post-logout",
clientID: "test-client-id",
idToken: "",
expected: "https://provider.example.com/logout?foo=bar&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &SessionHandler{
endSessionURL: test.endSessionURL,
postLogoutRedirectURI: test.postLogoutRedirectURI,
clientID: test.clientID,
}
result := handler.buildLogoutURL(test.idToken)
if result != test.expected {
t.Errorf("Expected '%s', got '%s'", test.expected, result)
}
})
}
}
func TestSessionHandler_ValidateSession(t *testing.T) {
handler := &SessionHandler{}
tests := []struct {
name string
session Session
expectedValid bool
expectedAuth bool
expectedMessage string
}{
{
name: "Nil session",
session: nil,
expectedValid: false,
expectedAuth: true,
expectedMessage: "session is nil",
},
{
name: "Not authenticated session",
session: &MockSession{
authenticated: false,
},
expectedValid: false,
expectedAuth: true,
expectedMessage: "session not authenticated",
},
{
name: "Authenticated session without email",
session: &MockSession{
authenticated: true,
email: "",
},
expectedValid: false,
expectedAuth: true,
expectedMessage: "no email in session",
},
{
name: "Valid authenticated session with email",
session: &MockSession{
authenticated: true,
email: "user@example.com",
},
expectedValid: true,
expectedAuth: false,
expectedMessage: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := handler.ValidateSession(test.session)
if result.Valid != test.expectedValid {
t.Errorf("Expected Valid %v, got %v", test.expectedValid, result.Valid)
}
if result.NeedsAuth != test.expectedAuth {
t.Errorf("Expected NeedsAuth %v, got %v", test.expectedAuth, result.NeedsAuth)
}
if result.ErrorMessage != test.expectedMessage {
t.Errorf("Expected ErrorMessage '%s', got '%s'", test.expectedMessage, result.ErrorMessage)
}
})
}
}
func TestSessionHandler_CleanupExpiredSession(t *testing.T) {
tests := []struct {
name string
session *MockSession
expectError bool
}{
{
name: "Successful cleanup",
session: &MockSession{
authenticated: true,
email: "user@example.com",
refreshToken: "refresh-token",
},
expectError: false,
},
{
name: "Save error during cleanup",
session: &MockSession{
authenticated: true,
email: "user@example.com",
refreshToken: "refresh-token",
saveError: fmt.Errorf("save failed"),
},
expectError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &SessionHandler{
logger: &MockLogger{},
}
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
err := handler.CleanupExpiredSession(rw, req, test.session)
if test.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !test.expectError && err != nil {
t.Errorf("Expected no error but got: %v", err)
}
if test.session != nil && !test.expectError {
if test.session.authenticated {
t.Error("Expected session authenticated to be false after cleanup")
}
if test.session.email != "" {
t.Error("Expected session email to be empty after cleanup")
}
if test.session.refreshToken != "" {
t.Error("Expected session refresh token to be empty after cleanup")
}
}
})
}
// Test nil session separately
t.Run("Nil session", func(t *testing.T) {
handler := &SessionHandler{
logger: &MockLogger{},
}
req := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
var nilSession Session = nil
err := handler.CleanupExpiredSession(rw, req, nilSession)
if err != nil {
t.Errorf("Expected no error for nil session, got: %v", err)
}
})
}
func TestSessionHandler_IsAjaxRequest(t *testing.T) {
handler := &SessionHandler{}
tests := []struct {
name string
headers map[string]string
expected bool
}{
{
name: "XMLHttpRequest header",
headers: map[string]string{
"X-Requested-With": "XMLHttpRequest",
},
expected: true,
},
{
name: "JSON Accept header without HTML",
headers: map[string]string{
"Accept": "application/json",
},
expected: true,
},
{
name: "JSON Accept header with HTML",
headers: map[string]string{
"Accept": "application/json, text/html",
},
expected: false,
},
{
name: "Fetch API CORS mode",
headers: map[string]string{
"Sec-Fetch-Mode": "cors",
},
expected: true,
},
{
name: "Regular browser request",
headers: map[string]string{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
},
expected: false,
},
{
name: "No special headers",
headers: map[string]string{},
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
result := handler.IsAjaxRequest(req)
if result != test.expected {
t.Errorf("Expected %v, got %v", test.expected, result)
}
})
}
}
func TestSessionHandler_SendErrorResponse(t *testing.T) {
tests := []struct {
name string
isAjax bool
message string
statusCode int
expectedContentType string
expectedBodyContains string
}{
{
name: "AJAX error response",
isAjax: true,
message: "Authentication failed",
statusCode: http.StatusUnauthorized,
expectedContentType: "application/json",
expectedBodyContains: `{"error": "Authentication failed"}`,
},
{
name: "Browser error response",
isAjax: false,
message: "Session expired",
statusCode: http.StatusForbidden,
expectedContentType: "text/html",
expectedBodyContains: "<h1>Error 403</h1>",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &SessionHandler{}
req := httptest.NewRequest("GET", "/", nil)
if test.isAjax {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}
rw := httptest.NewRecorder()
handler.SendErrorResponse(rw, req, test.message, test.statusCode)
if rw.Code != test.statusCode {
t.Errorf("Expected status code %d, got %d", test.statusCode, rw.Code)
}
contentType := rw.Header().Get("Content-Type")
if contentType != test.expectedContentType {
t.Errorf("Expected Content-Type '%s', got '%s'", test.expectedContentType, contentType)
}
body := rw.Body.String()
if !strings.Contains(body, test.expectedBodyContains) {
t.Errorf("Expected body to contain '%s', got '%s'", test.expectedBodyContains, body)
}
})
}
}
func TestSessionHandler_SetSecurityHeaders(t *testing.T) {
tests := []struct {
name string
method string
origin string
expectedCORS bool
expectedStatus int
}{
{
name: "Regular request without CORS",
method: "GET",
origin: "",
expectedCORS: false,
expectedStatus: 0, // No status written
},
{
name: "CORS request with origin",
method: "GET",
origin: "https://example.com",
expectedCORS: true,
expectedStatus: 0,
},
{
name: "OPTIONS preflight request",
method: "OPTIONS",
origin: "https://example.com",
expectedCORS: true,
expectedStatus: http.StatusOK,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
handler := &SessionHandler{}
req := httptest.NewRequest(test.method, "/", nil)
if test.origin != "" {
req.Header.Set("Origin", test.origin)
}
rw := httptest.NewRecorder()
handler.SetSecurityHeaders(rw, req)
// Check standard security headers
expectedSecurityHeaders := map[string]string{
"X-Frame-Options": "DENY",
"X-Content-Type-Options": "nosniff",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
}
for header, expectedValue := range expectedSecurityHeaders {
actualValue := rw.Header().Get(header)
if actualValue != expectedValue {
t.Errorf("Expected %s header '%s', got '%s'", header, expectedValue, actualValue)
}
}
// Check CORS headers
if test.expectedCORS {
corsOrigin := rw.Header().Get("Access-Control-Allow-Origin")
if corsOrigin != test.origin {
t.Errorf("Expected CORS origin '%s', got '%s'", test.origin, corsOrigin)
}
corsCredentials := rw.Header().Get("Access-Control-Allow-Credentials")
if corsCredentials != "true" {
t.Errorf("Expected CORS credentials 'true', got '%s'", corsCredentials)
}
corsMethods := rw.Header().Get("Access-Control-Allow-Methods")
if corsMethods != "GET, POST, OPTIONS" {
t.Errorf("Expected CORS methods 'GET, POST, OPTIONS', got '%s'", corsMethods)
}
corsHeaders := rw.Header().Get("Access-Control-Allow-Headers")
if corsHeaders != "Authorization, Content-Type" {
t.Errorf("Expected CORS headers 'Authorization, Content-Type', got '%s'", corsHeaders)
}
} else {
corsOrigin := rw.Header().Get("Access-Control-Allow-Origin")
if corsOrigin != "" {
t.Errorf("Expected no CORS origin header, got '%s'", corsOrigin)
}
}
// Check status code for OPTIONS requests
if test.expectedStatus > 0 {
if rw.Code != test.expectedStatus {
t.Errorf("Expected status code %d, got %d", test.expectedStatus, rw.Code)
}
}
})
}
}
func TestSessionValidationResult(t *testing.T) {
result := SessionValidationResult{
Valid: true,
NeedsAuth: false,
ErrorMessage: "test message",
}
if !result.Valid {
t.Error("Expected Valid to be true")
}
if result.NeedsAuth {
t.Error("Expected NeedsAuth to be false")
}
if result.ErrorMessage != "test message" {
t.Errorf("Expected ErrorMessage 'test message', got '%s'", result.ErrorMessage)
}
}
-545
View File
@@ -1,545 +0,0 @@
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/cookiejar"
"sync"
"sync/atomic"
"time"
)
// Config provides configuration for creating HTTP clients
type Config struct {
// Timeout for the entire request
Timeout time.Duration
// MaxRedirects allowed (0 means follow Go's default of 10)
MaxRedirects int
// UseCookieJar enables cookie jar for the client
UseCookieJar bool
// Connection settings
DialTimeout time.Duration
KeepAlive time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
// Connection pool settings
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Buffer settings
WriteBufferSize int
ReadBufferSize int
// Feature flags
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
// TLS configuration
TLSConfig *tls.Config
}
// ClientType defines the type of HTTP client for optimized behavior
type ClientType string
const (
ClientTypeDefault ClientType = "default"
ClientTypeToken ClientType = "token"
ClientTypeAPI ClientType = "api"
ClientTypeProxy ClientType = "proxy"
)
// PresetConfigs provides pre-configured settings for different client types
var PresetConfigs = map[ClientType]Config{
ClientTypeDefault: {
Timeout: 10 * time.Second, // Reduced from 30s to prevent slowloris attacks
MaxRedirects: 5, // Reduced from 10 to prevent redirect loops
UseCookieJar: false,
DialTimeout: 3 * time.Second,
KeepAlive: 15 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
ResponseHeaderTimeout: 3 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 5 * time.Second,
MaxIdleConns: 20, // Reduced from 100 to limit resource usage
MaxIdleConnsPerHost: 2, // Reduced from 10 to prevent connection exhaustion
MaxConnsPerHost: 5, // Reduced from 10 to limit concurrent connections
WriteBufferSize: 4096,
ReadBufferSize: 4096,
ForceHTTP2: true,
DisableKeepAlives: false,
DisableCompression: false,
},
ClientTypeToken: {
Timeout: 10 * time.Second,
MaxRedirects: 50, // Token endpoints may redirect more
UseCookieJar: true,
DialTimeout: 3 * time.Second,
KeepAlive: 15 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
ResponseHeaderTimeout: 3 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 5 * time.Second,
MaxIdleConns: 10,
MaxIdleConnsPerHost: 2,
MaxConnsPerHost: 5,
WriteBufferSize: 4096,
ReadBufferSize: 4096,
ForceHTTP2: true,
DisableKeepAlives: false,
DisableCompression: false,
},
ClientTypeAPI: {
Timeout: 30 * time.Second, // Longer for API operations
MaxRedirects: 10,
UseCookieJar: false,
DialTimeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 90 * time.Second,
MaxIdleConns: 50,
MaxIdleConnsPerHost: 5,
MaxConnsPerHost: 10,
WriteBufferSize: 8192,
ReadBufferSize: 8192,
ForceHTTP2: true,
DisableKeepAlives: false,
DisableCompression: false,
},
ClientTypeProxy: {
Timeout: 60 * time.Second, // Proxy needs longer timeouts
MaxRedirects: 0, // Proxy should not follow redirects
UseCookieJar: false,
DialTimeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 90 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 20,
WriteBufferSize: 16384,
ReadBufferSize: 16384,
ForceHTTP2: true,
DisableKeepAlives: false,
DisableCompression: true, // Proxy should not modify content
},
}
// Factory provides methods for creating configured HTTP clients
type Factory struct {
pool *TransportPool
logger Logger
}
// Logger interface for HTTP client operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
var (
globalFactory *Factory
globalFactoryOnce sync.Once
)
// GetGlobalFactory returns the singleton HTTP client factory
func GetGlobalFactory(logger Logger) *Factory {
globalFactoryOnce.Do(func() {
globalFactory = NewFactory(logger)
})
return globalFactory
}
// NewFactory creates a new HTTP client factory
func NewFactory(logger Logger) *Factory {
if logger == nil {
logger = &noOpLogger{}
}
return &Factory{
pool: GetGlobalTransportPool(),
logger: logger,
}
}
// CreateClient creates an HTTP client with the specified configuration
func (f *Factory) CreateClient(config Config) (*http.Client, error) {
// Validate configuration
if err := f.ValidateConfig(&config); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
// Apply TLS configuration if not provided
if config.TLSConfig == nil {
config.TLSConfig = f.createSecureTLSConfig()
}
// Get or create transport from pool
transport := f.pool.GetOrCreateTransport(config)
if transport == nil {
return nil, fmt.Errorf("failed to create transport: client limit exceeded")
}
// Create HTTP client
client := &http.Client{
Transport: transport,
Timeout: config.Timeout,
}
// Configure redirect policy
if config.MaxRedirects > 0 {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= config.MaxRedirects {
return fmt.Errorf("stopped after %d redirects", config.MaxRedirects)
}
return nil
}
}
// Add cookie jar if requested
if config.UseCookieJar {
jar, err := cookiejar.New(nil)
if err != nil {
return nil, fmt.Errorf("failed to create cookie jar: %w", err)
}
client.Jar = jar
}
f.logger.Debugf("Created HTTP client with config: timeout=%v, maxRedirects=%d", config.Timeout, config.MaxRedirects)
return client, nil
}
// CreateClientWithPreset creates an HTTP client using a preset configuration
func (f *Factory) CreateClientWithPreset(clientType ClientType) (*http.Client, error) {
config, ok := PresetConfigs[clientType]
if !ok {
return nil, fmt.Errorf("unknown client type: %s", clientType)
}
return f.CreateClient(config)
}
// CreateDefault creates a default HTTP client
func (f *Factory) CreateDefault() (*http.Client, error) {
return f.CreateClientWithPreset(ClientTypeDefault)
}
// CreateToken creates an HTTP client optimized for token operations
func (f *Factory) CreateToken() (*http.Client, error) {
return f.CreateClientWithPreset(ClientTypeToken)
}
// CreateAPI creates an HTTP client optimized for API operations
func (f *Factory) CreateAPI() (*http.Client, error) {
return f.CreateClientWithPreset(ClientTypeAPI)
}
// CreateProxy creates an HTTP client optimized for proxy operations
func (f *Factory) CreateProxy() (*http.Client, error) {
return f.CreateClientWithPreset(ClientTypeProxy)
}
// ValidateConfig validates HTTP client configuration parameters
func (f *Factory) ValidateConfig(config *Config) error {
// Validate connection pool limits
if config.MaxIdleConns < 0 {
return fmt.Errorf("MaxIdleConns cannot be negative: %d", config.MaxIdleConns)
}
if config.MaxIdleConns > 1000 {
return fmt.Errorf("MaxIdleConns too high (max 1000): %d", config.MaxIdleConns)
}
if config.MaxIdleConnsPerHost < 0 {
return fmt.Errorf("MaxIdleConnsPerHost cannot be negative: %d", config.MaxIdleConnsPerHost)
}
if config.MaxIdleConnsPerHost > 100 {
return fmt.Errorf("MaxIdleConnsPerHost too high (max 100): %d", config.MaxIdleConnsPerHost)
}
if config.MaxConnsPerHost < 0 {
return fmt.Errorf("MaxConnsPerHost cannot be negative: %d", config.MaxConnsPerHost)
}
if config.MaxConnsPerHost > 200 {
return fmt.Errorf("MaxConnsPerHost too high (max 200): %d", config.MaxConnsPerHost)
}
// Validate timeouts
if config.Timeout < 0 {
return fmt.Errorf("timeout cannot be negative")
}
if config.Timeout > 5*time.Minute {
return fmt.Errorf("timeout too long (max 5 minutes): %v", config.Timeout)
}
// Validate buffer sizes
if config.WriteBufferSize < 0 || config.ReadBufferSize < 0 {
return fmt.Errorf("buffer sizes cannot be negative")
}
if config.WriteBufferSize > 1024*1024 || config.ReadBufferSize > 1024*1024 {
return fmt.Errorf("buffer sizes too large (max 1MB)")
}
return nil
}
// createSecureTLSConfig creates a secure TLS configuration
func (f *Factory) createSecureTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12, // SECURITY: Enforce TLS 1.2 minimum
MaxVersion: tls.VersionTLS13, // Support up to TLS 1.3
CipherSuites: []uint16{
// TLS 1.3 cipher suites (automatically selected when TLS 1.3 is negotiated)
// TLS 1.2 secure cipher suites
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
InsecureSkipVerify: false, // SECURITY: Always verify certificates
PreferServerCipherSuites: false, // Let client choose best cipher
}
}
// TransportPool manages a pool of shared HTTP transports
type TransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
cancel context.CancelFunc
// Resource limits
clientCount int32 // Track total HTTP clients
maxClients int32 // Limit total clients
}
type sharedTransport struct {
transport *http.Transport
refCount int32
lastUsed time.Time
config Config
}
var (
globalTransportPool *TransportPool
globalTransportPoolOnce sync.Once
)
// GetGlobalTransportPool returns the singleton transport pool instance
func GetGlobalTransportPool() *TransportPool {
globalTransportPoolOnce.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
globalTransportPool = &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20, // Reduced from 100 to prevent resource exhaustion
ctx: ctx,
cancel: cancel,
clientCount: 0,
maxClients: 5, // Maximum 5 HTTP clients
}
// Start cleanup goroutine with context cancellation
go globalTransportPool.cleanupIdleTransports(ctx)
})
return globalTransportPool
}
// GetOrCreateTransport gets or creates a shared transport with the given config
func (p *TransportPool) GetOrCreateTransport(config Config) *http.Transport {
// Check client limit before creating new transport
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
// Try to return existing transport if limit reached
p.mu.RLock()
defer p.mu.RUnlock()
for _, shared := range p.transports {
if shared != nil && shared.transport != nil {
atomic.AddInt32(&shared.refCount, 1)
shared.lastUsed = time.Now()
return shared.transport
}
}
// If no transport available, return nil
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
key := p.configKey(config)
if shared, exists := p.transports[key]; exists {
atomic.AddInt32(&shared.refCount, 1)
shared.lastUsed = time.Now()
return shared.transport
}
// Create new transport
transport := p.createTransport(config)
p.transports[key] = &sharedTransport{
transport: transport,
refCount: 1,
lastUsed: time.Now(),
config: config,
}
atomic.AddInt32(&p.clientCount, 1)
return transport
}
// createTransport creates a new HTTP transport with the given configuration
func (p *TransportPool) createTransport(config Config) *http.Transport {
// Create secure TLS config if not provided
tlsConfig := config.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
}
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: config.DialTimeout,
KeepAlive: config.KeepAlive,
}).DialContext,
TLSClientConfig: tlsConfig,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
ResponseHeaderTimeout: config.ResponseHeaderTimeout,
ExpectContinueTimeout: config.ExpectContinueTimeout,
IdleConnTimeout: config.IdleConnTimeout,
MaxIdleConns: config.MaxIdleConns,
MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
MaxConnsPerHost: config.MaxConnsPerHost,
WriteBufferSize: config.WriteBufferSize,
ReadBufferSize: config.ReadBufferSize,
ForceAttemptHTTP2: config.ForceHTTP2,
DisableKeepAlives: config.DisableKeepAlives,
DisableCompression: config.DisableCompression,
}
}
// configKey generates a unique key for the configuration
func (p *TransportPool) configKey(config Config) string {
return fmt.Sprintf("%v-%d-%d-%d-%d-%v-%v-%v",
config.Timeout,
config.MaxIdleConns,
config.MaxIdleConnsPerHost,
config.MaxConnsPerHost,
config.MaxRedirects,
config.ForceHTTP2,
config.DisableKeepAlives,
config.DisableCompression,
)
}
// cleanupIdleTransports periodically cleans up idle transports
func (p *TransportPool) cleanupIdleTransports(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
p.cleanupIdle()
}
}
}
// cleanupIdle removes idle transports with zero references
func (p *TransportPool) cleanupIdle() {
p.mu.Lock()
defer p.mu.Unlock()
now := time.Now()
var toRemove []string
for key, shared := range p.transports {
if atomic.LoadInt32(&shared.refCount) == 0 && now.Sub(shared.lastUsed) > 10*time.Minute {
if shared.transport != nil {
shared.transport.CloseIdleConnections()
}
toRemove = append(toRemove, key)
}
}
for _, key := range toRemove {
delete(p.transports, key)
atomic.AddInt32(&p.clientCount, -1)
}
}
// Release decrements the reference count for a transport
func (p *TransportPool) Release(transport *http.Transport) {
p.mu.RLock()
defer p.mu.RUnlock()
for _, shared := range p.transports {
if shared.transport == transport {
atomic.AddInt32(&shared.refCount, -1)
return
}
}
}
// Close shuts down the transport pool
func (p *TransportPool) Close() error {
p.cancel()
p.mu.Lock()
defer p.mu.Unlock()
for key, shared := range p.transports {
if shared.transport != nil {
shared.transport.CloseIdleConnections()
}
delete(p.transports, key)
}
atomic.StoreInt32(&p.clientCount, 0)
return nil
}
// noOpLogger provides a no-op logger implementation
type noOpLogger struct{}
func (l *noOpLogger) Debug(msg string) {}
func (l *noOpLogger) Debugf(format string, args ...interface{}) {}
func (l *noOpLogger) Info(msg string) {}
func (l *noOpLogger) Infof(format string, args ...interface{}) {}
func (l *noOpLogger) Error(msg string) {}
func (l *noOpLogger) Errorf(format string, args ...interface{}) {}
// Compatibility functions for backward compatibility
// CreateDefaultHTTPClient creates a default HTTP client
func CreateDefaultHTTPClient() *http.Client {
factory := GetGlobalFactory(nil)
client, _ := factory.CreateDefault()
return client
}
// CreateTokenHTTPClient creates an HTTP client optimized for token operations
func CreateTokenHTTPClient() *http.Client {
factory := GetGlobalFactory(nil)
client, _ := factory.CreateToken()
return client
}
// CreateHTTPClientWithConfig creates an HTTP client with custom configuration
func CreateHTTPClientWithConfig(config Config) *http.Client {
factory := GetGlobalFactory(nil)
client, _ := factory.CreateClient(config)
return client
}
@@ -1,408 +0,0 @@
package httpclient
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestCreateProxy tests the CreateProxy method
func TestCreateProxy(t *testing.T) {
factory := NewFactory(nil)
client, err := factory.CreateProxy()
if err != nil {
t.Fatalf("Failed to create proxy client: %v", err)
}
if client == nil {
t.Fatal("Expected non-nil proxy client")
}
// Verify proxy configuration specifics
if client.Timeout != 60*time.Second {
t.Errorf("Expected proxy timeout to be 60s, got %v", client.Timeout)
}
}
// TestValidateConfigEdgeCases tests additional validation scenarios
func TestValidateConfigEdgeCases(t *testing.T) {
factory := NewFactory(nil)
testCases := []struct {
name string
config Config
shouldFail bool
errorMsg string
}{
{
name: "Negative MaxIdleConnsPerHost",
config: Config{
MaxIdleConnsPerHost: -1,
},
shouldFail: true,
errorMsg: "MaxIdleConnsPerHost cannot be negative",
},
{
name: "Excessive MaxIdleConnsPerHost",
config: Config{
MaxIdleConnsPerHost: 200,
},
shouldFail: true,
errorMsg: "MaxIdleConnsPerHost too high",
},
{
name: "Negative MaxConnsPerHost",
config: Config{
MaxConnsPerHost: -1,
},
shouldFail: true,
errorMsg: "MaxConnsPerHost cannot be negative",
},
{
name: "Excessive MaxConnsPerHost",
config: Config{
MaxConnsPerHost: 300,
},
shouldFail: true,
errorMsg: "MaxConnsPerHost too high",
},
{
name: "Negative WriteBufferSize",
config: Config{
WriteBufferSize: -1,
},
shouldFail: true,
errorMsg: "buffer sizes cannot be negative",
},
{
name: "Negative ReadBufferSize",
config: Config{
ReadBufferSize: -1,
},
shouldFail: true,
errorMsg: "buffer sizes cannot be negative",
},
{
name: "Excessive WriteBufferSize",
config: Config{
WriteBufferSize: 2 * 1024 * 1024,
},
shouldFail: true,
errorMsg: "buffer sizes too large",
},
{
name: "Excessive ReadBufferSize",
config: Config{
ReadBufferSize: 2 * 1024 * 1024,
},
shouldFail: true,
errorMsg: "buffer sizes too large",
},
{
name: "Valid edge values",
config: Config{
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 100,
MaxConnsPerHost: 200,
Timeout: 5 * time.Minute,
WriteBufferSize: 1024 * 1024,
ReadBufferSize: 1024 * 1024,
},
shouldFail: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := factory.ValidateConfig(&tc.config)
if tc.shouldFail {
if err == nil {
t.Fatalf("Expected validation to fail with message containing: %s", tc.errorMsg)
}
} else {
if err != nil {
t.Fatalf("Unexpected validation error: %v", err)
}
}
})
}
}
// TestTransportPoolClose tests the Close method of TransportPool
func TestTransportPoolClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
ctx: ctx,
cancel: cancel,
clientCount: 0,
maxClients: 5,
}
// Create some transports
config := PresetConfigs[ClientTypeDefault]
transport1 := pool.GetOrCreateTransport(config)
if transport1 == nil {
t.Fatal("Failed to create transport")
}
// Modify config slightly to create a different transport
config.Timeout = 20 * time.Second
transport2 := pool.GetOrCreateTransport(config)
if transport2 == nil {
t.Fatal("Failed to create second transport")
}
// Verify transports were created
pool.mu.RLock()
initialCount := len(pool.transports)
pool.mu.RUnlock()
if initialCount == 0 {
t.Fatal("Expected transports to be created")
}
// Close the pool
err := pool.Close()
if err != nil {
t.Fatalf("Failed to close pool: %v", err)
}
// Verify all transports were removed
pool.mu.RLock()
finalCount := len(pool.transports)
pool.mu.RUnlock()
if finalCount != 0 {
t.Fatalf("Expected 0 transports after close, got %d", finalCount)
}
// Verify client count was reset
if pool.clientCount != 0 {
t.Fatalf("Expected client count to be 0 after close, got %d", pool.clientCount)
}
}
// TestNoOpLogger tests the no-op logger implementation
func TestNoOpLogger(t *testing.T) {
logger := &noOpLogger{}
// These should not panic or cause any issues
logger.Debug("test debug")
logger.Debugf("test debug %s", "formatted")
logger.Info("test info")
logger.Infof("test info %s", "formatted")
logger.Error("test error")
logger.Errorf("test error %s", "formatted")
// Test using logger with factory
factory := NewFactory(logger)
client, err := factory.CreateDefault()
if err != nil {
t.Fatalf("Failed to create client with no-op logger: %v", err)
}
if client == nil {
t.Fatal("Expected non-nil client")
}
}
// TestCreateClientWithCustomTLS tests creating client with custom TLS config
func TestCreateClientWithCustomTLS(t *testing.T) {
factory := NewFactory(nil)
customTLS := &tls.Config{
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
}
config := Config{
Timeout: 10 * time.Second,
MaxIdleConns: 10,
MaxIdleConnsPerHost: 2,
MaxConnsPerHost: 5,
TLSConfig: customTLS,
}
client, err := factory.CreateClient(config)
if err != nil {
t.Fatalf("Failed to create client with custom TLS: %v", err)
}
if client == nil {
t.Fatal("Expected non-nil client")
}
}
// TestCreateClientWithMaxRedirects tests redirect limiting
func TestCreateClientWithMaxRedirects(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
redirectCount++
if redirectCount <= 3 {
http.Redirect(w, r, "/redirect", http.StatusFound)
} else {
w.WriteHeader(http.StatusOK)
w.Write([]byte("final"))
}
}))
defer server.Close()
factory := NewFactory(nil)
// Test with max redirects = 2 (should fail)
config := Config{
Timeout: 10 * time.Second,
MaxRedirects: 2,
MaxIdleConns: 10,
MaxIdleConnsPerHost: 2,
MaxConnsPerHost: 5,
}
client, err := factory.CreateClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
redirectCount = 0
_, err = client.Get(server.URL)
if err == nil {
t.Fatal("Expected redirect limit error")
}
// Test with max redirects = 5 (should succeed)
config.MaxRedirects = 5
client, err = factory.CreateClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
redirectCount = 0
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
}
// TestTransportPoolMaxClientsLimit tests the max clients limitation
func TestTransportPoolMaxClientsLimit(t *testing.T) {
pool := &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 2, // Set low limit for testing
}
// Create transports up to the limit
configs := []Config{
{Timeout: 10 * time.Second},
{Timeout: 20 * time.Second},
{Timeout: 30 * time.Second}, // This should not create a new transport
}
for i, config := range configs {
transport := pool.GetOrCreateTransport(config)
if i < 2 {
if transport == nil {
t.Fatalf("Expected transport %d to be created", i)
}
// Transport created successfully within limit
} else {
// When limit is reached, should return existing transport or nil
if transport == nil {
// This is acceptable - nil when limit reached
t.Log("Transport creation blocked due to client limit")
}
}
}
// Verify client count doesn't exceed limit
if pool.clientCount > pool.maxClients {
t.Fatalf("Client count %d exceeds max %d", pool.clientCount, pool.maxClients)
}
}
// TestCleanupIdleTransportsContext tests cleanup goroutine with context
func TestCleanupIdleTransportsContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
ctx: ctx,
cancel: cancel,
clientCount: 0,
maxClients: 5,
}
// Start cleanup goroutine
done := make(chan bool)
go func() {
pool.cleanupIdleTransports(ctx)
done <- true
}()
// Give it a moment to start
time.Sleep(10 * time.Millisecond)
// Cancel context to stop cleanup
cancel()
// Wait for goroutine to exit
select {
case <-done:
// Success - goroutine exited
case <-time.After(1 * time.Second):
t.Fatal("Cleanup goroutine did not exit after context cancellation")
}
}
// TestFactoryWithLogger tests factory creation with custom logger
func TestFactoryWithLogger(t *testing.T) {
// Create a mock logger that implements the Logger interface
logger := &MockLogger{}
factory := NewFactory(logger)
if factory.logger == nil {
t.Fatal("Expected logger to be set")
}
}
// MockLogger for testing
type MockLogger struct {
debugCalled bool
debugfCalled bool
infoCalled bool
infofCalled bool
errorCalled bool
errorfCalled bool
}
func (m *MockLogger) Debug(msg string) { m.debugCalled = true }
func (m *MockLogger) Debugf(format string, args ...interface{}) { m.debugfCalled = true }
func (m *MockLogger) Info(msg string) { m.infoCalled = true }
func (m *MockLogger) Infof(format string, args ...interface{}) { m.infofCalled = true }
func (m *MockLogger) Error(msg string) { m.errorCalled = true }
func (m *MockLogger) Errorf(format string, args ...interface{}) { m.errorfCalled = true }
// TestCreateClientLogging tests that logger is called during client creation
func TestCreateClientLogging(t *testing.T) {
logger := &MockLogger{}
factory := NewFactory(logger)
client, err := factory.CreateDefault()
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
if client == nil {
t.Fatal("Expected non-nil client")
}
// Verify logger was called
if !logger.debugfCalled {
t.Error("Expected Debugf to be called during client creation")
}
}
-299
View File
@@ -1,299 +0,0 @@
package httpclient
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestFactoryCreateClient(t *testing.T) {
factory := NewFactory(nil)
// Test creating default client
client, err := factory.CreateDefault()
if err != nil {
t.Fatalf("Failed to create default client: %v", err)
}
if client == nil {
t.Fatal("Expected non-nil client")
}
// Test creating token client
tokenClient, err := factory.CreateToken()
if err != nil {
t.Fatalf("Failed to create token client: %v", err)
}
if tokenClient == nil {
t.Fatal("Expected non-nil token client")
}
}
func TestFactoryCreateClientWithPreset(t *testing.T) {
factory := NewFactory(nil)
testCases := []struct {
name string
clientType ClientType
shouldFail bool
}{
{"Default", ClientTypeDefault, false},
{"Token", ClientTypeToken, false},
{"API", ClientTypeAPI, false},
{"Proxy", ClientTypeProxy, false},
{"Invalid", ClientType("invalid"), true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client, err := factory.CreateClientWithPreset(tc.clientType)
if tc.shouldFail {
if err == nil {
t.Fatal("Expected error for invalid client type")
}
} else {
if err != nil {
t.Fatalf("Failed to create %s client: %v", tc.clientType, err)
}
if client == nil {
t.Fatal("Expected non-nil client")
}
}
})
}
}
func TestFactoryValidateConfig(t *testing.T) {
factory := NewFactory(nil)
testCases := []struct {
name string
config Config
shouldFail bool
}{
{
name: "Valid config",
config: PresetConfigs[ClientTypeDefault],
shouldFail: false,
},
{
name: "Negative MaxIdleConns",
config: Config{
MaxIdleConns: -1,
},
shouldFail: true,
},
{
name: "Excessive MaxIdleConns",
config: Config{
MaxIdleConns: 2000,
},
shouldFail: true,
},
{
name: "Negative timeout",
config: Config{
Timeout: -1 * time.Second,
},
shouldFail: true,
},
{
name: "Excessive timeout",
config: Config{
Timeout: 10 * time.Minute,
},
shouldFail: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := factory.ValidateConfig(&tc.config)
if tc.shouldFail && err == nil {
t.Fatal("Expected validation to fail")
}
if !tc.shouldFail && err != nil {
t.Fatalf("Unexpected validation error: %v", err)
}
})
}
}
func TestTransportPoolConcurrency(t *testing.T) {
pool := &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := PresetConfigs[ClientTypeDefault]
var wg sync.WaitGroup
numGoroutines := 10
// Test concurrent transport creation
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
transport := pool.GetOrCreateTransport(config)
if transport != nil {
// Simulate usage
time.Sleep(10 * time.Millisecond)
pool.Release(transport)
}
}()
}
wg.Wait()
// Verify client count is within limits
clientCount := atomic.LoadInt32(&pool.clientCount)
if clientCount > pool.maxClients {
t.Fatalf("Client count %d exceeds max %d", clientCount, pool.maxClients)
}
}
func TestHTTPClientRequests(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
}))
defer server.Close()
factory := NewFactory(nil)
client, err := factory.CreateDefault()
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Make request
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
}
func TestClientWithCookieJar(t *testing.T) {
config := PresetConfigs[ClientTypeToken]
if !config.UseCookieJar {
t.Skip("Token client should have cookie jar enabled")
}
factory := NewFactory(nil)
client, err := factory.CreateToken()
if err != nil {
t.Fatalf("Failed to create token client: %v", err)
}
if client.Jar == nil {
t.Fatal("Expected cookie jar to be set")
}
}
func TestTransportPoolCleanup(t *testing.T) {
pool := &TransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := PresetConfigs[ClientTypeDefault]
// Create transport
transport := pool.GetOrCreateTransport(config)
if transport == nil {
t.Fatal("Failed to create transport")
}
// Release transport
pool.Release(transport)
// Simulate idle time
pool.mu.Lock()
for _, shared := range pool.transports {
shared.lastUsed = time.Now().Add(-11 * time.Minute)
atomic.StoreInt32(&shared.refCount, 0)
}
pool.mu.Unlock()
// Run cleanup
pool.cleanupIdle()
// Verify transport was removed
pool.mu.RLock()
count := len(pool.transports)
pool.mu.RUnlock()
if count != 0 {
t.Fatalf("Expected 0 transports after cleanup, got %d", count)
}
}
func TestGlobalFactorySingleton(t *testing.T) {
factory1 := GetGlobalFactory(nil)
factory2 := GetGlobalFactory(nil)
if factory1 != factory2 {
t.Fatal("Expected singleton factory instances to be the same")
}
}
func TestCompatibilityFunctions(t *testing.T) {
// Test CreateDefaultHTTPClient
defaultClient := CreateDefaultHTTPClient()
if defaultClient == nil {
t.Fatal("Expected non-nil default client")
}
// Test CreateTokenHTTPClient
tokenClient := CreateTokenHTTPClient()
if tokenClient == nil {
t.Fatal("Expected non-nil token client")
}
// Test CreateHTTPClientWithConfig
config := PresetConfigs[ClientTypeAPI]
apiClient := CreateHTTPClientWithConfig(config)
if apiClient == nil {
t.Fatal("Expected non-nil API client")
}
}
func BenchmarkFactoryCreateClient(b *testing.B) {
factory := NewFactory(nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
client, err := factory.CreateDefault()
if err != nil || client == nil {
b.Fatal("Failed to create client")
}
}
})
}
func BenchmarkTransportPoolGetOrCreate(b *testing.B) {
pool := GetGlobalTransportPool()
config := PresetConfigs[ClientTypeDefault]
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
transport := pool.GetOrCreateTransport(config)
if transport != nil {
pool.Release(transport)
}
}
})
}
-83
View File
@@ -1,83 +0,0 @@
package logger
import (
"fmt"
"log"
)
// LegacyLoggerAdapter wraps the old Logger struct from the main package
// to implement the new unified Logger interface. This allows for gradual
// migration of the codebase to the new logger interface.
type LegacyLoggerAdapter struct {
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
}
// NewLegacyAdapter creates a new adapter from the old logger components
func NewLegacyAdapter(logError, logInfo, logDebug *log.Logger) Logger {
if logError == nil || logInfo == nil || logDebug == nil {
return GetNoOpLogger()
}
return &LegacyLoggerAdapter{
logError: logError,
logInfo: logInfo,
logDebug: logDebug,
}
}
// Debug logs a debug message
func (l *LegacyLoggerAdapter) Debug(msg string) {
l.logDebug.Print(msg)
}
// Debugf logs a formatted debug message
func (l *LegacyLoggerAdapter) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Info logs an info message
func (l *LegacyLoggerAdapter) Info(msg string) {
l.logInfo.Print(msg)
}
// Infof logs a formatted info message
func (l *LegacyLoggerAdapter) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Error logs an error message
func (l *LegacyLoggerAdapter) Error(msg string) {
l.logError.Print(msg)
}
// Errorf logs a formatted error message
func (l *LegacyLoggerAdapter) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Printf logs a formatted message at info level
func (l *LegacyLoggerAdapter) Printf(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Println logs a message at info level
func (l *LegacyLoggerAdapter) Println(args ...interface{}) {
l.logInfo.Print(args...)
}
// Fatalf logs a formatted error message and panics
func (l *LegacyLoggerAdapter) Fatalf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
panic(fmt.Sprintf(format, args...))
}
// WithField returns the same logger (no structured logging support in legacy adapter)
func (l *LegacyLoggerAdapter) WithField(key string, value interface{}) Logger {
return l
}
// WithFields returns the same logger (no structured logging support in legacy adapter)
func (l *LegacyLoggerAdapter) WithFields(fields map[string]interface{}) Logger {
return l
}
-182
View File
@@ -1,182 +0,0 @@
package logger
import (
"io"
"os"
"sync"
)
// Factory creates and manages logger instances with singleton support
// for common logger types to reduce memory allocation.
type Factory struct {
mu sync.RWMutex
defaultLogger Logger
noOpLogger Logger
loggers map[string]Logger
defaultLogLevel string
}
var (
// globalFactory is the singleton factory instance
globalFactory *Factory
// factoryOnce ensures the factory is created only once
factoryOnce sync.Once
)
// GetFactory returns the global logger factory instance
func GetFactory() *Factory {
factoryOnce.Do(func() {
globalFactory = &Factory{
loggers: make(map[string]Logger),
defaultLogLevel: "info",
}
})
return globalFactory
}
// SetDefaultLogLevel sets the default log level for new loggers
func (f *Factory) SetDefaultLogLevel(level string) {
f.mu.Lock()
defer f.mu.Unlock()
f.defaultLogLevel = level
}
// GetLogger returns a logger for the given name, creating one if it doesn't exist
func (f *Factory) GetLogger(name string) Logger {
f.mu.RLock()
if logger, exists := f.loggers[name]; exists {
f.mu.RUnlock()
return logger
}
f.mu.RUnlock()
// Create new logger
f.mu.Lock()
defer f.mu.Unlock()
// Double check after acquiring write lock
if logger, exists := f.loggers[name]; exists {
return logger
}
logger := f.createLogger(name)
f.loggers[name] = logger
return logger
}
// createLogger creates a new logger instance
func (f *Factory) createLogger(name string) Logger {
if name == "noop" || name == "no-op" || name == "discard" {
return GetNoOpLogger()
}
// Create logger with appropriate outputs based on environment
var errorOut, infoOut, debugOut io.Writer
if os.Getenv("OIDC_LOG_TO_FILE") == "true" {
// Log to files if configured
errorOut = getOrCreateLogFile("error.log")
infoOut = getOrCreateLogFile("info.log")
debugOut = getOrCreateLogFile("debug.log")
} else {
// Default to stdout/stderr
errorOut = os.Stderr
infoOut = os.Stdout
debugOut = os.Stdout
}
return NewStandardLogger(f.defaultLogLevel, errorOut, infoOut, debugOut)
}
// GetDefaultLogger returns the default logger instance
func (f *Factory) GetDefaultLogger() Logger {
f.mu.RLock()
if f.defaultLogger != nil {
f.mu.RUnlock()
return f.defaultLogger
}
f.mu.RUnlock()
f.mu.Lock()
defer f.mu.Unlock()
if f.defaultLogger == nil {
f.defaultLogger = f.createLogger("default")
}
return f.defaultLogger
}
// GetNoOpLogger returns the singleton no-op logger
func (f *Factory) GetNoOpLogger() Logger {
f.mu.RLock()
if f.noOpLogger != nil {
f.mu.RUnlock()
return f.noOpLogger
}
f.mu.RUnlock()
f.mu.Lock()
defer f.mu.Unlock()
if f.noOpLogger == nil {
f.noOpLogger = GetNoOpLogger()
}
return f.noOpLogger
}
// Clear removes all cached loggers (useful for testing)
func (f *Factory) Clear() {
f.mu.Lock()
defer f.mu.Unlock()
f.loggers = make(map[string]Logger)
f.defaultLogger = nil
// Don't clear noOpLogger as it's a singleton
}
// getOrCreateLogFile returns a file writer for the given log file
func getOrCreateLogFile(filename string) io.Writer {
logDir := os.Getenv("OIDC_LOG_DIR")
if logDir == "" {
logDir = "/var/log/traefik-oidc"
}
// Ensure log directory exists
if err := os.MkdirAll(logDir, 0755); err != nil {
// Fall back to stderr if we can't create the directory
return os.Stderr
}
filepath := logDir + "/" + filename
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
// Fall back to stderr if we can't open the file
return os.Stderr
}
return file
}
// Global convenience functions
// New creates a new logger with the specified level
func New(level string) Logger {
return GetFactory().GetLogger(level)
}
// Default returns the default logger
func Default() Logger {
return GetFactory().GetDefaultLogger()
}
// NoOp returns a no-op logger
func NoOp() Logger {
return GetFactory().GetNoOpLogger()
}
// WithLevel creates a new logger with the specified level
func WithLevel(level string) Logger {
return NewStandardLogger(level, os.Stderr, os.Stdout, os.Stdout)
}
-312
View File
@@ -1,312 +0,0 @@
// Package logger provides a unified logging interface for the entire application.
// It consolidates all the duplicate logger interfaces into a single, comprehensive
// interface that supports different log levels and structured logging.
package logger
import (
"fmt"
"io"
"log"
"sync"
)
// Logger is the unified interface for all logging operations in the application.
// It combines all the methods from the various logger interfaces that were
// previously scattered across different packages.
type Logger interface {
// Basic logging methods
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
// Additional methods for compatibility with existing code
Printf(format string, args ...interface{})
Println(args ...interface{})
Fatalf(format string, args ...interface{})
// Structured logging support
WithField(key string, value interface{}) Logger
WithFields(fields map[string]interface{}) Logger
}
// StandardLogger implements the Logger interface using Go's standard log package.
// It provides thread-safe logging with different output streams for different log levels.
type StandardLogger struct {
mu sync.RWMutex
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
fields map[string]interface{}
level LogLevel
}
// LogLevel represents the logging level
type LogLevel int
const (
// LogLevelDebug enables all log messages
LogLevelDebug LogLevel = iota
// LogLevelInfo enables info and error messages
LogLevelInfo
// LogLevelError enables only error messages
LogLevelError
// LogLevelNone disables all logging
LogLevelNone
)
// ParseLogLevel converts a string log level to LogLevel
func ParseLogLevel(level string) LogLevel {
switch level {
case "debug", "DEBUG":
return LogLevelDebug
case "info", "INFO":
return LogLevelInfo
case "error", "ERROR":
return LogLevelError
case "none", "NONE":
return LogLevelNone
default:
return LogLevelInfo
}
}
// NewStandardLogger creates a new StandardLogger with the specified log level
func NewStandardLogger(level string, errorOutput, infoOutput, debugOutput io.Writer) *StandardLogger {
logLevel := ParseLogLevel(level)
if errorOutput == nil {
errorOutput = io.Discard
}
if infoOutput == nil {
infoOutput = io.Discard
}
if debugOutput == nil {
debugOutput = io.Discard
}
return &StandardLogger{
logError: log.New(errorOutput, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile),
logInfo: log.New(infoOutput, "INFO: ", log.Ldate|log.Ltime),
logDebug: log.New(debugOutput, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile),
fields: make(map[string]interface{}),
level: logLevel,
}
}
// Debug logs a debug message
func (l *StandardLogger) Debug(msg string) {
if l.level <= LogLevelDebug {
l.mu.RLock()
defer l.mu.RUnlock()
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logDebug.Print(msg)
}
}
// Debugf logs a formatted debug message
func (l *StandardLogger) Debugf(format string, args ...interface{}) {
if l.level <= LogLevelDebug {
l.mu.RLock()
defer l.mu.RUnlock()
msg := fmt.Sprintf(format, args...)
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logDebug.Print(msg)
}
}
// Info logs an info message
func (l *StandardLogger) Info(msg string) {
if l.level <= LogLevelInfo {
l.mu.RLock()
defer l.mu.RUnlock()
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logInfo.Print(msg)
}
}
// Infof logs a formatted info message
func (l *StandardLogger) Infof(format string, args ...interface{}) {
if l.level <= LogLevelInfo {
l.mu.RLock()
defer l.mu.RUnlock()
msg := fmt.Sprintf(format, args...)
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logInfo.Print(msg)
}
}
// Error logs an error message
func (l *StandardLogger) Error(msg string) {
if l.level <= LogLevelError {
l.mu.RLock()
defer l.mu.RUnlock()
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logError.Print(msg)
}
}
// Errorf logs a formatted error message
func (l *StandardLogger) Errorf(format string, args ...interface{}) {
if l.level <= LogLevelError {
l.mu.RLock()
defer l.mu.RUnlock()
msg := fmt.Sprintf(format, args...)
if len(l.fields) > 0 {
msg = l.formatWithFields(msg)
}
l.logError.Print(msg)
}
}
// Printf logs a formatted message at info level
func (l *StandardLogger) Printf(format string, args ...interface{}) {
l.Infof(format, args...)
}
// Println logs a message at info level
func (l *StandardLogger) Println(args ...interface{}) {
l.Info(fmt.Sprint(args...))
}
// Fatalf logs a formatted error message and exits the program
func (l *StandardLogger) Fatalf(format string, args ...interface{}) {
l.Errorf(format, args...)
panic(fmt.Sprintf(format, args...))
}
// WithField returns a new logger with an additional field
func (l *StandardLogger) WithField(key string, value interface{}) Logger {
l.mu.Lock()
defer l.mu.Unlock()
newLogger := &StandardLogger{
logError: l.logError,
logInfo: l.logInfo,
logDebug: l.logDebug,
fields: make(map[string]interface{}, len(l.fields)+1),
level: l.level,
}
for k, v := range l.fields {
newLogger.fields[k] = v
}
newLogger.fields[key] = value
return newLogger
}
// WithFields returns a new logger with additional fields
func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger {
l.mu.Lock()
defer l.mu.Unlock()
newLogger := &StandardLogger{
logError: l.logError,
logInfo: l.logInfo,
logDebug: l.logDebug,
fields: make(map[string]interface{}, len(l.fields)+len(fields)),
level: l.level,
}
for k, v := range l.fields {
newLogger.fields[k] = v
}
for k, v := range fields {
newLogger.fields[k] = v
}
return newLogger
}
// formatWithFields formats a message with structured fields
func (l *StandardLogger) formatWithFields(msg string) string {
if len(l.fields) == 0 {
return msg
}
fieldsStr := ""
for k, v := range l.fields {
if fieldsStr != "" {
fieldsStr += " "
}
fieldsStr += fmt.Sprintf("%s=%v", k, v)
}
return fmt.Sprintf("%s [%s]", msg, fieldsStr)
}
// NoOpLogger is a logger that discards all output.
// It's useful for testing and for cases where logging should be disabled.
type NoOpLogger struct{}
// Debug discards the message
func (n *NoOpLogger) Debug(msg string) {}
// Debugf discards the formatted message
func (n *NoOpLogger) Debugf(format string, args ...interface{}) {}
// Info discards the message
func (n *NoOpLogger) Info(msg string) {}
// Infof discards the formatted message
func (n *NoOpLogger) Infof(format string, args ...interface{}) {}
// Error discards the message
func (n *NoOpLogger) Error(msg string) {}
// Errorf discards the formatted message
func (n *NoOpLogger) Errorf(format string, args ...interface{}) {}
// Printf discards the formatted message
func (n *NoOpLogger) Printf(format string, args ...interface{}) {}
// Println discards the message
func (n *NoOpLogger) Println(args ...interface{}) {}
// Fatalf discards the message and does not exit
func (n *NoOpLogger) Fatalf(format string, args ...interface{}) {}
// WithField returns the same NoOpLogger
func (n *NoOpLogger) WithField(key string, value interface{}) Logger {
return n
}
// WithFields returns the same NoOpLogger
func (n *NoOpLogger) WithFields(fields map[string]interface{}) Logger {
return n
}
var (
// singletonNoOpLogger is the global instance of the no-op logger
singletonNoOpLogger *NoOpLogger
// noOpLoggerOnce ensures the singleton is created only once
noOpLoggerOnce sync.Once
)
// GetNoOpLogger returns the singleton no-op logger instance.
// This reduces memory allocation by reusing the same no-op logger
// instance across the entire application.
func GetNoOpLogger() Logger {
noOpLoggerOnce.Do(func() {
singletonNoOpLogger = &NoOpLogger{}
})
return singletonNoOpLogger
}
// DefaultLogger creates a default logger based on the provided configuration
func DefaultLogger(level string) Logger {
return NewStandardLogger(level, log.Writer(), log.Writer(), log.Writer())
}
File diff suppressed because it is too large Load Diff
-122
View File
@@ -1,122 +0,0 @@
package middleware
import (
"fmt"
"net/http"
"strings"
"time"
)
// RequestContext holds request processing context
type RequestContext struct {
Writer http.ResponseWriter
Request *http.Request
RedirectURL string
Scheme string
Host string
}
// RequestProcessor handles common request processing operations
type RequestProcessor struct {
logger Logger
}
// Logger interface for logging operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
}
// NewRequestProcessor creates a new request processor
func NewRequestProcessor(logger Logger) *RequestProcessor {
return &RequestProcessor{
logger: logger,
}
}
// BuildRequestContext creates a request context with scheme and host detection
func (rp *RequestProcessor) BuildRequestContext(rw http.ResponseWriter, req *http.Request, redirectPath string) *RequestContext {
scheme := rp.determineScheme(req)
host := rp.determineHost(req)
redirectURL := buildFullURL(scheme, host, redirectPath)
return &RequestContext{
Writer: rw,
Request: req,
RedirectURL: redirectURL,
Scheme: scheme,
Host: host,
}
}
// IsHealthCheckRequest checks if request is a health check
func (rp *RequestProcessor) IsHealthCheckRequest(req *http.Request) bool {
return strings.HasPrefix(req.URL.Path, "/health")
}
// IsEventStreamRequest checks if request expects event stream
func (rp *RequestProcessor) IsEventStreamRequest(req *http.Request) bool {
acceptHeader := req.Header.Get("Accept")
return strings.Contains(acceptHeader, "text/event-stream")
}
// IsAjaxRequest determines if this is an AJAX request
func (rp *RequestProcessor) IsAjaxRequest(req *http.Request) bool {
xhr := req.Header.Get("X-Requested-With")
contentType := req.Header.Get("Content-Type")
accept := req.Header.Get("Accept")
return xhr == "XMLHttpRequest" ||
strings.Contains(contentType, "application/json") ||
strings.Contains(accept, "application/json")
}
// WaitForInitialization waits for OIDC provider initialization with timeout
func (rp *RequestProcessor) WaitForInitialization(req *http.Request, initComplete <-chan struct{}) error {
select {
case <-initComplete:
return nil
case <-req.Context().Done():
rp.logger.Debug("Request canceled while waiting for OIDC initialization")
return fmt.Errorf("request canceled")
case <-time.After(30 * time.Second):
rp.logger.Error("Timeout waiting for OIDC initialization")
return fmt.Errorf("timeout waiting for OIDC provider initialization")
}
}
// determineScheme determines the URL scheme for building redirect URLs
func (rp *RequestProcessor) determineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// determineHost determines the host for building redirect URLs
func (rp *RequestProcessor) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
// buildFullURL constructs a complete URL from scheme, host, and path components
func buildFullURL(scheme, host, path string) string {
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
-655
View File
@@ -1,655 +0,0 @@
package middleware
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// MockLogger implements the Logger interface for testing
type MockLogger struct {
DebugCalls []string
DebugfCalls []string
ErrorCalls []string
ErrorfCalls []string
InfoCalls []string
InfofCalls []string
}
func (m *MockLogger) Debug(msg string) {
m.DebugCalls = append(m.DebugCalls, msg)
}
func (m *MockLogger) Debugf(format string, args ...interface{}) {
m.DebugfCalls = append(m.DebugfCalls, format)
}
func (m *MockLogger) Error(msg string) {
m.ErrorCalls = append(m.ErrorCalls, msg)
}
func (m *MockLogger) Errorf(format string, args ...interface{}) {
m.ErrorfCalls = append(m.ErrorfCalls, format)
}
func (m *MockLogger) Info(msg string) {
m.InfoCalls = append(m.InfoCalls, msg)
}
func (m *MockLogger) Infof(format string, args ...interface{}) {
m.InfofCalls = append(m.InfofCalls, format)
}
// TestNewRequestProcessor tests the constructor
func TestNewRequestProcessor(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
if processor == nil {
t.Error("Expected NewRequestProcessor to return non-nil processor")
return
}
if processor.logger != logger {
t.Error("Expected processor to use provided logger")
}
}
// TestBuildRequestContext tests request context building
func TestBuildRequestContext(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setupRequest func() (*http.Request, http.ResponseWriter)
redirectPath string
expectedURL string
expectedHost string
}{
{
name: "Basic HTTP request",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "http://example.com/callback",
expectedHost: "example.com",
},
{
name: "HTTPS request with TLS",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "https://secure.com/test", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/auth",
expectedURL: "https://secure.com/auth",
expectedHost: "secure.com",
},
{
name: "Request with X-Forwarded-Proto header",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Proto", "https")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "https://internal.com/callback",
expectedHost: "internal.com",
},
{
name: "Request with X-Forwarded-Host header",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Host", "public.com")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/callback",
expectedURL: "http://public.com/callback",
expectedHost: "public.com",
},
{
name: "Request with both forwarded headers",
setupRequest: func() (*http.Request, http.ResponseWriter) {
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "public.com")
rw := httptest.NewRecorder()
return req, rw
},
redirectPath: "/auth",
expectedURL: "https://public.com/auth",
expectedHost: "public.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, rw := tt.setupRequest()
ctx := processor.BuildRequestContext(rw, req, tt.redirectPath)
if ctx == nil {
t.Error("Expected BuildRequestContext to return non-nil context")
return
}
if ctx.Writer != rw {
t.Error("Expected context writer to match provided writer")
}
if ctx.Request != req {
t.Error("Expected context request to match provided request")
}
if ctx.RedirectURL != tt.expectedURL {
t.Errorf("Expected redirect URL '%s', got '%s'", tt.expectedURL, ctx.RedirectURL)
}
if ctx.Host != tt.expectedHost {
t.Errorf("Expected host '%s', got '%s'", tt.expectedHost, ctx.Host)
}
})
}
}
// TestIsHealthCheckRequest tests health check detection
func TestIsHealthCheckRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
path string
expected bool
}{
{
name: "Health check path",
path: "/health",
expected: true,
},
{
name: "Health check subpath",
path: "/health/status",
expected: true,
},
{
name: "Health check with query params",
path: "/health?check=db",
expected: true,
},
{
name: "Not a health check",
path: "/api/users",
expected: false,
},
{
name: "Health-related path (matches prefix)",
path: "/healthiness",
expected: true, // HasPrefix behavior - this actually matches
},
{
name: "Root path",
path: "/",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+tt.path, nil)
result := processor.IsHealthCheckRequest(req)
if result != tt.expected {
t.Errorf("Expected IsHealthCheckRequest to return %v for path '%s', got %v", tt.expected, tt.path, result)
}
})
}
}
// TestIsEventStreamRequest tests event stream detection
func TestIsEventStreamRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
acceptHeader string
expected bool
}{
{
name: "Event stream accept header",
acceptHeader: "text/event-stream",
expected: true,
},
{
name: "Event stream with other types",
acceptHeader: "text/html, text/event-stream, application/json",
expected: true,
},
{
name: "JSON accept header",
acceptHeader: "application/json",
expected: false,
},
{
name: "HTML accept header",
acceptHeader: "text/html,application/xhtml+xml",
expected: false,
},
{
name: "Empty accept header",
acceptHeader: "",
expected: false,
},
{
name: "Similar but not event stream",
acceptHeader: "text/event-source",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
if tt.acceptHeader != "" {
req.Header.Set("Accept", tt.acceptHeader)
}
result := processor.IsEventStreamRequest(req)
if result != tt.expected {
t.Errorf("Expected IsEventStreamRequest to return %v for accept header '%s', got %v", tt.expected, tt.acceptHeader, result)
}
})
}
}
// TestIsAjaxRequest tests AJAX request detection
func TestIsAjaxRequest(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setupHeader func(*http.Request)
expected bool
}{
{
name: "XMLHttpRequest header",
setupHeader: func(req *http.Request) {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
},
expected: true,
},
{
name: "JSON content type",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
},
expected: true,
},
{
name: "JSON content type with charset",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
},
expected: true,
},
{
name: "JSON accept header",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "application/json")
},
expected: true,
},
{
name: "JSON accept with other types",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "text/html, application/json, application/xml")
},
expected: true,
},
{
name: "Multiple AJAX indicators",
setupHeader: func(req *http.Request) {
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
},
expected: true,
},
{
name: "Regular HTML request",
setupHeader: func(req *http.Request) {
req.Header.Set("Accept", "text/html,application/xhtml+xml")
},
expected: false,
},
{
name: "Form submission",
setupHeader: func(req *http.Request) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
},
expected: false,
},
{
name: "No special headers",
setupHeader: func(req *http.Request) {},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "http://example.com/api", nil)
tt.setupHeader(req)
result := processor.IsAjaxRequest(req)
if result != tt.expected {
t.Errorf("Expected IsAjaxRequest to return %v, got %v", tt.expected, result)
}
})
}
}
// TestWaitForInitialization tests initialization waiting
func TestWaitForInitialization(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
t.Run("Initialization completes successfully", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
initComplete := make(chan struct{})
go func() {
time.Sleep(10 * time.Millisecond)
close(initComplete)
}()
err := processor.WaitForInitialization(req, initComplete)
if err != nil {
t.Errorf("Expected no error when initialization completes, got: %v", err)
}
})
t.Run("Request context canceled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req = req.WithContext(ctx)
initComplete := make(chan struct{})
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
err := processor.WaitForInitialization(req, initComplete)
if err == nil {
t.Error("Expected error when request context is canceled")
}
if !strings.Contains(err.Error(), "request canceled") {
t.Errorf("Expected 'request canceled' error, got: %v", err)
}
if len(logger.DebugCalls) == 0 {
t.Error("Expected debug log when request is canceled")
}
})
t.Run("Initialization timeout", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping timeout test in short mode")
}
req := httptest.NewRequest("GET", "http://example.com/test", nil)
initComplete := make(chan struct{}) // Never closes
// Note: This test takes 30 seconds due to hardcoded timeout in implementation
start := time.Now()
err := processor.WaitForInitialization(req, initComplete)
duration := time.Since(start)
if err == nil {
t.Error("Expected timeout error")
}
if !strings.Contains(err.Error(), "timeout") {
t.Errorf("Expected timeout error, got: %v", err)
}
// The timeout should be around 30 seconds, allow some variance
if duration < 29*time.Second || duration > 31*time.Second {
t.Errorf("Expected timeout after ~30 seconds, but got %v", duration)
}
if len(logger.ErrorCalls) == 0 {
t.Error("Expected error log when timeout occurs")
}
})
}
// TestDetermineScheme tests scheme determination
func TestDetermineScheme(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setup func(*http.Request)
expected string
}{
{
name: "X-Forwarded-Proto HTTPS",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "https")
},
expected: "https",
},
{
name: "X-Forwarded-Proto HTTP",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "http")
},
expected: "http",
},
{
name: "TLS connection without header",
setup: func(req *http.Request) {
req.TLS = &tls.ConnectionState{}
},
expected: "https",
},
{
name: "No TLS, no header",
setup: func(req *http.Request) {
// No special setup
},
expected: "http",
},
{
name: "X-Forwarded-Proto takes precedence over TLS",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Proto", "http")
req.TLS = &tls.ConnectionState{}
},
expected: "http",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
tt.setup(req)
result := processor.determineScheme(req)
if result != tt.expected {
t.Errorf("Expected scheme '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestDetermineHost tests host determination
func TestDetermineHost(t *testing.T) {
logger := &MockLogger{}
processor := NewRequestProcessor(logger)
tests := []struct {
name string
setup func(*http.Request)
expected string
}{
{
name: "X-Forwarded-Host header present",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Host", "public.example.com")
},
expected: "public.example.com",
},
{
name: "No X-Forwarded-Host, use req.Host",
setup: func(req *http.Request) {
// No special setup, will use req.Host
},
expected: "example.com",
},
{
name: "Empty X-Forwarded-Host, fallback to req.Host",
setup: func(req *http.Request) {
req.Header.Set("X-Forwarded-Host", "")
},
expected: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
tt.setup(req)
result := processor.determineHost(req)
if result != tt.expected {
t.Errorf("Expected host '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestBuildFullURL tests URL building
func TestBuildFullURL(t *testing.T) {
tests := []struct {
name string
scheme string
host string
path string
expected string
}{
{
name: "Basic URL construction",
scheme: "https",
host: "example.com",
path: "/callback",
expected: "https://example.com/callback",
},
{
name: "Path without leading slash",
scheme: "http",
host: "test.com",
path: "auth",
expected: "http://test.com/auth",
},
{
name: "Absolute HTTP URL in path",
scheme: "https",
host: "example.com",
path: "http://other.com/callback",
expected: "http://other.com/callback",
},
{
name: "Absolute HTTPS URL in path",
scheme: "http",
host: "example.com",
path: "https://secure.com/auth",
expected: "https://secure.com/auth",
},
{
name: "Root path",
scheme: "https",
host: "example.com:8080",
path: "/",
expected: "https://example.com:8080/",
},
{
name: "Empty path",
scheme: "https",
host: "example.com",
path: "",
expected: "https://example.com/",
},
{
name: "Path with query parameters",
scheme: "https",
host: "example.com",
path: "/callback?state=abc123",
expected: "https://example.com/callback?state=abc123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildFullURL(tt.scheme, tt.host, tt.path)
if result != tt.expected {
t.Errorf("Expected URL '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestRequestContext tests the RequestContext struct
func TestRequestContext(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rw := httptest.NewRecorder()
ctx := &RequestContext{
Writer: rw,
Request: req,
RedirectURL: "https://example.com/callback",
Scheme: "https",
Host: "example.com",
}
if ctx.Writer != rw {
t.Error("Expected Writer to be set correctly")
}
if ctx.Request != req {
t.Error("Expected Request to be set correctly")
}
if ctx.RedirectURL != "https://example.com/callback" {
t.Error("Expected RedirectURL to be set correctly")
}
if ctx.Scheme != "https" {
t.Error("Expected Scheme to be set correctly")
}
if ctx.Host != "example.com" {
t.Error("Expected Host to be set correctly")
}
}
-309
View File
@@ -1,309 +0,0 @@
// Package patterns provides cached compiled regex patterns for performance optimization
package patterns
import (
"regexp"
"sync"
)
// RegexCache manages compiled regex patterns with thread-safe access
type RegexCache struct {
patterns map[string]*regexp.Regexp
mu sync.RWMutex
}
// NewRegexCache creates a new regex cache instance
func NewRegexCache() *RegexCache {
return &RegexCache{
patterns: make(map[string]*regexp.Regexp),
}
}
// Get retrieves a compiled regex pattern, compiling and caching it if not present
func (c *RegexCache) Get(pattern string) (*regexp.Regexp, error) {
// First try read lock for existing pattern
c.mu.RLock()
if regex, exists := c.patterns[pattern]; exists {
c.mu.RUnlock()
return regex, nil
}
c.mu.RUnlock()
// Pattern not found, acquire write lock to compile and cache
c.mu.Lock()
defer c.mu.Unlock()
// Double-check in case another goroutine compiled it while we waited
if regex, exists := c.patterns[pattern]; exists {
return regex, nil
}
// Compile the pattern
regex, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
// Cache the compiled pattern
c.patterns[pattern] = regex
return regex, nil
}
// MustGet is like Get but panics if the pattern cannot be compiled
func (c *RegexCache) MustGet(pattern string) *regexp.Regexp {
regex, err := c.Get(pattern)
if err != nil {
panic("regex compilation failed for pattern '" + pattern + "': " + err.Error())
}
return regex
}
// Precompile compiles and caches multiple patterns at once
func (c *RegexCache) Precompile(patterns []string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, pattern := range patterns {
if _, exists := c.patterns[pattern]; !exists {
regex, err := regexp.Compile(pattern)
if err != nil {
return err
}
c.patterns[pattern] = regex
}
}
return nil
}
// Size returns the number of cached patterns
func (c *RegexCache) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.patterns)
}
// Clear removes all cached patterns
func (c *RegexCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.patterns = make(map[string]*regexp.Regexp)
}
// Global regex cache instance
var globalCache = NewRegexCache()
// Common regex patterns used throughout the OIDC implementation
const (
// Email validation pattern (RFC 5322 compliant)
EmailPattern = `^[a-zA-Z0-9.!#$%&'*+/=?^_` + "`" + `{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`
// Domain validation pattern
DomainPattern = `^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`
// URL validation pattern (http/https)
URLPattern = `^https?://[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*(/.*)?$`
// JWT token pattern (three base64url parts separated by dots)
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
// Bearer token pattern (Authorization header)
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
// Client ID pattern (alphanumeric with common separators)
ClientIDPattern = `^[a-zA-Z0-9._-]+$`
// Scope pattern (space-separated alphanumeric with underscores)
ScopePattern = `^[a-zA-Z0-9_]+(\s+[a-zA-Z0-9_]+)*$`
// Session ID pattern (hexadecimal)
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
// CSRF token pattern (base64url)
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
// Nonce pattern (base64url)
NoncePattern = `^[A-Za-z0-9_-]+$`
// Code verifier pattern for PKCE (base64url, 43-128 chars)
CodeVerifierPattern = `^[A-Za-z0-9_-]{43,128}$`
// Authorization code pattern (base64url)
AuthCodePattern = `^[A-Za-z0-9._~+/-]+=*$`
// Redirect URI validation (must be absolute HTTP/HTTPS URL)
RedirectURIPattern = `^https?://[^\s/$.?#].[^\s]*$`
// User-Agent pattern for bot detection
BotUserAgentPattern = `(?i)(bot|crawler|spider|scraper|curl|wget|python|java|go-http)`
// IP address pattern (IPv4)
IPv4Pattern = `^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`
// Tenant ID pattern (UUID format for Azure, etc.)
TenantIDPattern = `^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`
)
// Precompiled common patterns for immediate use
var (
EmailRegex *regexp.Regexp
DomainRegex *regexp.Regexp
URLRegex *regexp.Regexp
JWTRegex *regexp.Regexp
BearerTokenRegex *regexp.Regexp
ClientIDRegex *regexp.Regexp
ScopeRegex *regexp.Regexp
SessionIDRegex *regexp.Regexp
CSRFTokenRegex *regexp.Regexp
NonceRegex *regexp.Regexp
CodeVerifierRegex *regexp.Regexp
AuthCodeRegex *regexp.Regexp
RedirectURIRegex *regexp.Regexp
BotUserAgentRegex *regexp.Regexp
IPv4Regex *regexp.Regexp
TenantIDRegex *regexp.Regexp
)
// Initialize precompiled patterns
func init() {
commonPatterns := []string{
EmailPattern,
DomainPattern,
URLPattern,
JWTPattern,
BearerTokenPattern,
ClientIDPattern,
ScopePattern,
SessionIDPattern,
CSRFTokenPattern,
NoncePattern,
CodeVerifierPattern,
AuthCodePattern,
RedirectURIPattern,
BotUserAgentPattern,
IPv4Pattern,
TenantIDPattern,
}
if err := globalCache.Precompile(commonPatterns); err != nil {
panic("Failed to precompile common regex patterns: " + err.Error())
}
// Assign precompiled patterns to global variables for easy access
EmailRegex = globalCache.MustGet(EmailPattern)
DomainRegex = globalCache.MustGet(DomainPattern)
URLRegex = globalCache.MustGet(URLPattern)
JWTRegex = globalCache.MustGet(JWTPattern)
BearerTokenRegex = globalCache.MustGet(BearerTokenPattern)
ClientIDRegex = globalCache.MustGet(ClientIDPattern)
ScopeRegex = globalCache.MustGet(ScopePattern)
SessionIDRegex = globalCache.MustGet(SessionIDPattern)
CSRFTokenRegex = globalCache.MustGet(CSRFTokenPattern)
NonceRegex = globalCache.MustGet(NoncePattern)
CodeVerifierRegex = globalCache.MustGet(CodeVerifierPattern)
AuthCodeRegex = globalCache.MustGet(AuthCodePattern)
RedirectURIRegex = globalCache.MustGet(RedirectURIPattern)
BotUserAgentRegex = globalCache.MustGet(BotUserAgentPattern)
IPv4Regex = globalCache.MustGet(IPv4Pattern)
TenantIDRegex = globalCache.MustGet(TenantIDPattern)
}
// Global helper functions for common validations
// ValidateEmail checks if an email address is valid
func ValidateEmail(email string) bool {
return EmailRegex.MatchString(email)
}
// ValidateDomain checks if a domain name is valid
func ValidateDomain(domain string) bool {
return DomainRegex.MatchString(domain)
}
// ValidateURL checks if a URL is valid (http/https)
func ValidateURL(url string) bool {
return URLRegex.MatchString(url)
}
// ValidateJWT checks if a token has valid JWT format
func ValidateJWT(token string) bool {
return JWTRegex.MatchString(token)
}
// ExtractBearerToken extracts the token from a Bearer authorization header
func ExtractBearerToken(authHeader string) (string, bool) {
matches := BearerTokenRegex.FindStringSubmatch(authHeader)
if len(matches) == 2 {
return matches[1], true
}
return "", false
}
// ValidateClientID checks if a client ID has valid format
func ValidateClientID(clientID string) bool {
return ClientIDRegex.MatchString(clientID)
}
// ValidateScopes checks if scopes string has valid format
func ValidateScopes(scopes string) bool {
return ScopeRegex.MatchString(scopes)
}
// ValidateSessionID checks if a session ID has valid format
func ValidateSessionID(sessionID string) bool {
return SessionIDRegex.MatchString(sessionID)
}
// ValidateCSRFToken checks if a CSRF token has valid format
func ValidateCSRFToken(token string) bool {
return CSRFTokenRegex.MatchString(token)
}
// ValidateNonce checks if a nonce has valid format
func ValidateNonce(nonce string) bool {
return NonceRegex.MatchString(nonce)
}
// ValidateCodeVerifier checks if a PKCE code verifier has valid format
func ValidateCodeVerifier(verifier string) bool {
return CodeVerifierRegex.MatchString(verifier)
}
// ValidateAuthCode checks if an authorization code has valid format
func ValidateAuthCode(code string) bool {
return AuthCodeRegex.MatchString(code)
}
// ValidateRedirectURI checks if a redirect URI is valid
func ValidateRedirectURI(uri string) bool {
return RedirectURIRegex.MatchString(uri)
}
// IsBotUserAgent checks if a User-Agent suggests an automated client
func IsBotUserAgent(userAgent string) bool {
return BotUserAgentRegex.MatchString(userAgent)
}
// ValidateIPv4 checks if an IP address is valid IPv4
func ValidateIPv4(ip string) bool {
return IPv4Regex.MatchString(ip)
}
// ValidateTenantID checks if a tenant ID has valid UUID format
func ValidateTenantID(tenantID string) bool {
return TenantIDRegex.MatchString(tenantID)
}
// GetGlobalCache returns the global regex cache instance
func GetGlobalCache() *RegexCache {
return globalCache
}
// CompilePattern compiles a pattern using the global cache
func CompilePattern(pattern string) (*regexp.Regexp, error) {
return globalCache.Get(pattern)
}
// MustCompilePattern compiles a pattern using the global cache, panicking on error
func MustCompilePattern(pattern string) *regexp.Regexp {
return globalCache.MustGet(pattern)
}
-484
View File
@@ -1,484 +0,0 @@
package patterns
import (
"regexp"
"sync"
"testing"
)
func TestRegexCache_Get(t *testing.T) {
cache := NewRegexCache()
pattern := `^test\d+$`
// First call should compile and cache
regex1, err := cache.Get(pattern)
if err != nil {
t.Fatalf("Failed to get regex: %v", err)
}
// Second call should return cached version
regex2, err := cache.Get(pattern)
if err != nil {
t.Fatalf("Failed to get cached regex: %v", err)
}
// Should be the same instance
if regex1 != regex2 {
t.Error("Expected same regex instance from cache")
}
// Test the regex works
if !regex1.MatchString("test123") {
t.Error("Regex should match 'test123'")
}
if regex1.MatchString("test") {
t.Error("Regex should not match 'test'")
}
}
func TestRegexCache_ConcurrentAccess(t *testing.T) {
cache := NewRegexCache()
pattern := `^concurrent\d+$`
var wg sync.WaitGroup
results := make([]*regexp.Regexp, 10)
errors := make([]error, 10)
// Launch multiple goroutines to access the same pattern
for i := 0; i < 10; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
regex, err := cache.Get(pattern)
results[index] = regex
errors[index] = err
}(i)
}
wg.Wait()
// Check all succeeded
for i, err := range errors {
if err != nil {
t.Fatalf("Goroutine %d failed: %v", i, err)
}
}
// All should return the same instance
first := results[0]
for i, regex := range results[1:] {
if regex != first {
t.Errorf("Goroutine %d got different regex instance", i+1)
}
}
}
func TestRegexCache_InvalidPattern(t *testing.T) {
cache := NewRegexCache()
_, err := cache.Get(`[invalid`)
if err == nil {
t.Error("Expected error for invalid regex pattern")
}
}
func TestRegexCache_Precompile(t *testing.T) {
cache := NewRegexCache()
patterns := []string{
`^test1$`,
`^test2$`,
`^test3$`,
}
err := cache.Precompile(patterns)
if err != nil {
t.Fatalf("Failed to precompile patterns: %v", err)
}
if cache.Size() != 3 {
t.Errorf("Expected cache size 3, got %d", cache.Size())
}
// Should be able to get precompiled patterns without error
for _, pattern := range patterns {
_, err := cache.Get(pattern)
if err != nil {
t.Errorf("Failed to get precompiled pattern %s: %v", pattern, err)
}
}
}
func TestValidationFunctions(t *testing.T) {
tests := []struct {
name string
function func(string) bool
valid []string
invalid []string
}{
{
name: "ValidateEmail",
function: ValidateEmail,
valid: []string{"test@example.com", "user.name@domain.org", "admin+tag@company.co.uk"},
invalid: []string{"invalid-email", "@domain.com", "user@", ""},
},
{
name: "ValidateDomain",
function: ValidateDomain,
valid: []string{"example.com", "sub.domain.org", "test.co.uk"},
invalid: []string{"", "invalid..domain", ".example.com", "domain."},
},
{
name: "ValidateJWT",
function: ValidateJWT,
valid: []string{"eyJ0.eyJ1.sig", "a.b.c"},
invalid: []string{"invalid", "a.b", "a.b.c.d", ""},
},
{
name: "ValidateClientID",
function: ValidateClientID,
valid: []string{"client123", "my-client_id", "123.456"},
invalid: []string{"", "client with spaces", "client@invalid"},
},
{
name: "ValidateURL",
function: ValidateURL,
valid: []string{"https://example.com", "https://sub.domain.org/path", "http://localhost", "https://example.com/path?query=value", "http://192.168.1.1"},
invalid: []string{"", "ftp://example.com", "not-a-url", "https://", "example.com", "http://localhost:8080"},
},
{
name: "ValidateScopes",
function: ValidateScopes,
valid: []string{"openid", "openid profile", "read write admin", "user_info"},
invalid: []string{"", "scope-with-dash", "scope@invalid", "scope with.dot", " "},
},
{
name: "ValidateSessionID",
function: ValidateSessionID,
valid: []string{"a1b2c3d4e5f6789012345678901234567890abcdef", "ABCDEF1234567890abcdef1234567890", "0123456789abcdef0123456789abcdef"},
invalid: []string{"", "too-short", "contains-invalid-chars!", "g123456789abcdef0123456789abcdef", "1234567890abcdef1234567890abcde"},
},
{
name: "ValidateCSRFToken",
function: ValidateCSRFToken,
valid: []string{"abc123", "ABC_123-xyz", "token-value_123", "_valid-token_"},
invalid: []string{"", "token with spaces", "token@invalid", "token.with.dots!", "token/with/slash"},
},
{
name: "ValidateNonce",
function: ValidateNonce,
valid: []string{"abc123", "ABC_123-xyz", "nonce-value_123", "_valid-nonce_"},
invalid: []string{"", "nonce with spaces", "nonce@invalid", "nonce.with.dots!", "nonce/with/slash"},
},
{
name: "ValidateCodeVerifier",
function: ValidateCodeVerifier,
valid: []string{"dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"},
invalid: []string{"", "too-short", "short", "verifier with spaces", "verifier@invalid", "a"},
},
{
name: "ValidateAuthCode",
function: ValidateAuthCode,
valid: []string{"auth_code_123", "ABC.123-xyz/code+value=", "simple-code"},
invalid: []string{"", "code with spaces", "code@invalid"},
},
{
name: "ValidateRedirectURI",
function: ValidateRedirectURI,
valid: []string{"https://example.com/callback", "http://localhost:8080/auth", "https://app.example.org/oauth/callback", "http://127.0.0.1:3000"},
invalid: []string{"", "ftp://example.com", "not-a-url", "example.com/callback", "https://"},
},
{
name: "ValidateIPv4",
function: ValidateIPv4,
valid: []string{"192.168.1.1", "10.0.0.1", "127.0.0.1", "255.255.255.255", "0.0.0.0"},
invalid: []string{"", "256.1.1.1", "192.168.1", "192.168.1.1.1", "not-an-ip"},
},
{
name: "ValidateTenantID",
function: ValidateTenantID,
valid: []string{"12345678-1234-1234-1234-123456789abc", "ABCDEF12-3456-7890-ABCD-EF1234567890"},
invalid: []string{"", "not-a-uuid", "12345678-1234-1234-1234", "12345678-1234-1234-1234-123456789abcd", "123456781234123412341234567890ab"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for _, valid := range tt.valid {
if !tt.function(valid) {
t.Errorf("%s should be valid: %s", tt.name, valid)
}
}
for _, invalid := range tt.invalid {
if tt.function(invalid) {
t.Errorf("%s should be invalid: %s", tt.name, invalid)
}
}
})
}
}
func TestExtractBearerToken(t *testing.T) {
tests := []struct {
header string
expected string
valid bool
}{
{"Bearer abc123", "abc123", true},
{"Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9", "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9", true},
{"bearer token123", "", false}, // case sensitive
{"Basic abc123", "", false},
{"Bearer", "", false},
{"", "", false},
}
for _, tt := range tests {
token, valid := ExtractBearerToken(tt.header)
if valid != tt.valid {
t.Errorf("ExtractBearerToken(%q) valid = %v, want %v", tt.header, valid, tt.valid)
}
if token != tt.expected {
t.Errorf("ExtractBearerToken(%q) token = %q, want %q", tt.header, token, tt.expected)
}
}
}
func BenchmarkRegexCache_Get(b *testing.B) {
cache := NewRegexCache()
pattern := `^benchmark\d+$`
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := cache.Get(pattern)
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkRegexCache_Validation(b *testing.B) {
email := "test@example.com"
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ValidateEmail(email)
}
})
}
func BenchmarkRegex_DirectCompile(b *testing.B) {
pattern := `^benchmark\d+$`
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := regexp.Compile(pattern)
if err != nil {
b.Fatal(err)
}
}
}
func TestRegexCache_Clear(t *testing.T) {
cache := NewRegexCache()
// Add some patterns to the cache
patterns := []string{`^test1$`, `^test2$`, `^test3$`}
for _, pattern := range patterns {
_, err := cache.Get(pattern)
if err != nil {
t.Fatalf("Failed to add pattern %s: %v", pattern, err)
}
}
// Verify cache has patterns
if cache.Size() != 3 {
t.Errorf("Expected cache size 3, got %d", cache.Size())
}
// Clear the cache
cache.Clear()
// Verify cache is empty
if cache.Size() != 0 {
t.Errorf("Expected cache size 0 after clear, got %d", cache.Size())
}
}
func TestIsBotUserAgent(t *testing.T) {
tests := []struct {
userAgent string
isBot bool
}{
{"Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)", true},
{"Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)", true},
{"facebookexternalhit/1.1 (+http://www.facebook.com/externalhit_uatext.php)", false},
{"crawler-bot/1.0", true},
{"spider-agent/2.0", true},
{"curl/7.68.0", true},
{"wget/1.20.3", true},
{"python-requests/2.25.1", true},
{"Go-http-client/1.1", true},
{"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", false},
{"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.userAgent, func(t *testing.T) {
result := IsBotUserAgent(tt.userAgent)
if result != tt.isBot {
t.Errorf("IsBotUserAgent(%q) = %v, want %v", tt.userAgent, result, tt.isBot)
}
})
}
}
func TestGetGlobalCache(t *testing.T) {
cache := GetGlobalCache()
if cache == nil {
t.Error("GetGlobalCache() should not return nil")
}
// Should return the same instance
cache2 := GetGlobalCache()
if cache != cache2 {
t.Error("GetGlobalCache() should return the same instance")
}
// Should have precompiled patterns
if cache.Size() == 0 {
t.Error("Global cache should have precompiled patterns")
}
}
func TestCompilePattern(t *testing.T) {
pattern := `^test_compile\d+$`
regex, err := CompilePattern(pattern)
if err != nil {
t.Fatalf("CompilePattern failed: %v", err)
}
if !regex.MatchString("test_compile123") {
t.Error("Compiled pattern should match 'test_compile123'")
}
if regex.MatchString("test_compile") {
t.Error("Compiled pattern should not match 'test_compile'")
}
// Test invalid pattern
_, err = CompilePattern(`[invalid`)
if err == nil {
t.Error("Expected error for invalid pattern")
}
}
func TestMustCompilePattern(t *testing.T) {
pattern := `^test_must_compile\d+$`
regex := MustCompilePattern(pattern)
if regex == nil {
t.Fatal("MustCompilePattern should not return nil")
}
if !regex.MatchString("test_must_compile456") {
t.Error("Compiled pattern should match 'test_must_compile456'")
}
// Test that it panics with invalid pattern
defer func() {
if r := recover(); r == nil {
t.Error("MustCompilePattern should panic with invalid pattern")
}
}()
MustCompilePattern(`[invalid`)
}
func TestAdditionalValidationEdgeCases(t *testing.T) {
// Test edge cases for ValidateURL
t.Run("ValidateURL_EdgeCases", func(t *testing.T) {
edgeCases := []struct {
url string
valid bool
}{
{"https://a.b", true},
{"http://localhost", true},
{"https://example.com/path?query=value#fragment", true},
{"http://192.168.0.1:8080/api", false},
{"https://", false},
{"http://", false},
{"https://example", true},
}
for _, tc := range edgeCases {
result := ValidateURL(tc.url)
if result != tc.valid {
t.Errorf("ValidateURL(%q) = %v, want %v", tc.url, result, tc.valid)
}
}
})
// Test edge cases for ValidateScopes
t.Run("ValidateScopes_EdgeCases", func(t *testing.T) {
edgeCases := []struct {
scopes string
valid bool
}{
{"a", true},
{"a b", true},
{"openid profile email", true},
{"user_profile", true},
{"read_all write_all", true},
{"scope-with-dash", false},
{"scope.with.dot", false},
{"scope@email", false},
{" scope", false},
{"scope ", false},
{"a b", true}, // pattern allows multiple spaces
}
for _, tc := range edgeCases {
result := ValidateScopes(tc.scopes)
if result != tc.valid {
t.Errorf("ValidateScopes(%q) = %v, want %v", tc.scopes, result, tc.valid)
}
}
})
// Test edge cases for ValidateSessionID
t.Run("ValidateSessionID_EdgeCases", func(t *testing.T) {
edgeCases := []struct {
sessionID string
valid bool
}{
{"12345678901234567890123456789012", true}, // 32 chars (min)
{"1234567890123456789012345678901", false}, // 31 chars (too short)
{string(make([]byte, 128)), false}, // 128 non-hex chars
{"abcdef1234567890ABCDEF1234567890" + string(make([]byte, 96)), false}, // 128+ chars with non-hex
}
// Generate valid 128-char hex string (max length)
validLongHex := ""
for i := 0; i < 128; i++ {
validLongHex += "a"
}
edgeCases = append(edgeCases, struct {
sessionID string
valid bool
}{validLongHex, true})
for _, tc := range edgeCases {
result := ValidateSessionID(tc.sessionID)
if result != tc.valid {
t.Errorf("ValidateSessionID(%q) = %v, want %v", tc.sessionID, result, tc.valid)
}
}
})
}
+3 -1
View File
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
PreferServerCipherSuites: true,
InsecureSkipVerify: config.InsecureSkipVerify,
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
InsecureSkipVerify: config.InsecureSkipVerify,
}
return &http.Transport{
+3
View File
@@ -148,6 +148,7 @@ func (cb *CircuitBreaker) allowRequest() bool {
// allowHalfOpenRequest checks if a request is allowed in half-open state
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
// #nosec G115 -- MaxRequests is a small config value that fits in int32
if current <= int32(cb.config.MaxRequests) {
return true
}
@@ -164,6 +165,7 @@ func (cb *CircuitBreaker) recordFailure() {
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
cb.transitionToOpen()
} else if state == CircuitBreakerHalfOpen {
@@ -180,6 +182,7 @@ func (cb *CircuitBreaker) recordSuccess() {
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
cb.transitionToClosed()
}
+1
View File
@@ -191,6 +191,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
}
// Add jitter
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.RandomizationFactor > 0 {
jitter := delay * re.config.RandomizationFactor
minDelay := delay - jitter
+796
View File
@@ -0,0 +1,796 @@
//go:build !yaegi
package recovery
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// =============================================================================
// RETRY CONFIG TESTS
// =============================================================================
func TestDefaultRetryConfig(t *testing.T) {
config := DefaultRetryConfig()
if config.MaxAttempts != 3 {
t.Errorf("Expected MaxAttempts to be 3, got %d", config.MaxAttempts)
}
if config.InitialDelay != 100*time.Millisecond {
t.Errorf("Expected InitialDelay to be 100ms, got %v", config.InitialDelay)
}
if config.MaxDelay != 30*time.Second {
t.Errorf("Expected MaxDelay to be 30s, got %v", config.MaxDelay)
}
if config.Multiplier != 2.0 {
t.Errorf("Expected Multiplier to be 2.0, got %f", config.Multiplier)
}
if config.RandomizationFactor != 0.1 {
t.Errorf("Expected RandomizationFactor to be 0.1, got %f", config.RandomizationFactor)
}
if len(config.RetryableErrors) != 3 {
t.Errorf("Expected 3 retryable errors, got %d", len(config.RetryableErrors))
}
if len(config.RetryableStatusCodes) != 6 {
t.Errorf("Expected 6 retryable status codes, got %d", len(config.RetryableStatusCodes))
}
}
// =============================================================================
// RETRY EXECUTOR TESTS
// =============================================================================
func TestNewRetryExecutor(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
if executor == nil {
t.Fatal("Expected NewRetryExecutor to return non-nil")
}
if executor.config.MaxAttempts != 3 {
t.Errorf("Expected MaxAttempts to be 3, got %d", executor.config.MaxAttempts)
}
}
func TestNewRetryExecutor_InvalidConfig(t *testing.T) {
logger := &mockLogger{}
// Test with invalid MaxAttempts
config := RetryConfig{
MaxAttempts: 0, // Invalid
Multiplier: 0, // Invalid
}
executor := NewRetryExecutor(config, logger)
if executor.config.MaxAttempts != 1 {
t.Errorf("Expected MaxAttempts to be corrected to 1, got %d", executor.config.MaxAttempts)
}
if executor.config.Multiplier != 1.0 {
t.Errorf("Expected Multiplier to be corrected to 1.0, got %f", executor.config.Multiplier)
}
}
func TestRetryExecutor_ExecuteWithContext_Success(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
callCount := 0
err := executor.ExecuteWithContext(context.Background(), func() error {
callCount++
return nil
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
}
func TestRetryExecutor_ExecuteWithContext_Retry(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 10 * time.Millisecond,
Multiplier: 2.0,
RetryableErrors: []string{"connection refused"},
}
executor := NewRetryExecutor(config, logger)
callCount := 0
err := executor.ExecuteWithContext(context.Background(), func() error {
callCount++
if callCount < 3 {
return errors.New("connection refused")
}
return nil
})
if err != nil {
t.Errorf("Expected success after retries, got %v", err)
}
if callCount != 3 {
t.Errorf("Expected function to be called 3 times, got %d", callCount)
}
}
func TestRetryExecutor_ExecuteWithContext_MaxRetriesExhausted(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 10 * time.Millisecond,
Multiplier: 2.0,
RetryableErrors: []string{"timeout"},
}
executor := NewRetryExecutor(config, logger)
callCount := 0
err := executor.ExecuteWithContext(context.Background(), func() error {
callCount++
return errors.New("timeout")
})
if err == nil {
t.Error("Expected error after max retries exhausted")
}
if !strings.Contains(err.Error(), "all retry attempts failed") {
t.Errorf("Expected 'all retry attempts failed' error, got %v", err)
}
if callCount != 3 {
t.Errorf("Expected function to be called 3 times, got %d", callCount)
}
}
func TestRetryExecutor_ExecuteWithContext_NonRetryableError(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
RetryableErrors: []string{"timeout"},
}
executor := NewRetryExecutor(config, logger)
callCount := 0
err := executor.ExecuteWithContext(context.Background(), func() error {
callCount++
return errors.New("permanent error")
})
if err == nil {
t.Error("Expected error for non-retryable error")
}
if callCount != 1 {
t.Errorf("Expected function to be called once (non-retryable), got %d", callCount)
}
}
func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Second, // Long delay
RetryableErrors: []string{"timeout"},
}
executor := NewRetryExecutor(config, logger)
ctx, cancel := context.WithCancel(context.Background())
callCount := 0
var wg sync.WaitGroup
wg.Add(1)
var execErr error
go func() {
defer wg.Done()
execErr = executor.ExecuteWithContext(ctx, func() error {
callCount++
return errors.New("timeout")
})
}()
// Cancel after a short delay
time.Sleep(50 * time.Millisecond)
cancel()
wg.Wait()
if execErr == nil {
t.Error("Expected error when context is cancelled")
}
}
func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
err := executor.ExecuteWithContext(ctx, func() error {
return nil
})
if err == nil {
t.Error("Expected error when context is already cancelled")
}
}
func TestRetryExecutor_Execute(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
called := false
err := executor.Execute(context.Background(), func() error {
called = true
return nil
})
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !called {
t.Error("Expected function to be called")
}
}
func TestRetryExecutor_isRetryableError(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
RetryableErrors: []string{"connection refused", "timeout"},
RetryableStatusCodes: []int{500, 503},
}
executor := NewRetryExecutor(config, logger)
tests := []struct {
name string
err error
expected bool
}{
{"nil error", nil, false},
{"connection refused", errors.New("connection refused"), true},
{"timeout", errors.New("TIMEOUT"), true}, // case insensitive
{"EOF", errors.New("EOF"), false},
{"random error", errors.New("something else"), false},
{"context cancelled", context.Canceled, false},
{"context deadline exceeded", context.DeadlineExceeded, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := executor.isRetryableError(tt.err)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestRetryExecutor_isRetryableError_HTTPError(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
RetryableStatusCodes: []int{500, 503},
}
executor := NewRetryExecutor(config, logger)
tests := []struct {
name string
statusCode int
expected bool
}{
{"500 error", 500, true},
{"503 error", 503, true},
{"502 error (5xx)", 502, true},
{"400 error", 400, false},
{"401 error", 401, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
httpErr := &HTTPError{StatusCode: tt.statusCode}
result := executor.isRetryableError(httpErr)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestRetryExecutor_isRetryableError_OIDCError(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
// Test retryable OIDC error
retryableErr := &OIDCError{Code: "temporarily_unavailable", Description: "Server busy"}
if !executor.isRetryableError(retryableErr) {
t.Error("Expected temporarily_unavailable to be retryable")
}
// Test non-retryable OIDC error
nonRetryableErr := &OIDCError{Code: "invalid_token", Description: "Token expired"}
if executor.isRetryableError(nonRetryableErr) {
t.Error("Expected invalid_token to not be retryable")
}
}
func TestRetryExecutor_calculateDelay(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
InitialDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
Multiplier: 2.0,
RandomizationFactor: 0.0, // No jitter for predictable tests
}
executor := NewRetryExecutor(config, logger)
// Test exponential backoff without jitter
delay1 := executor.calculateDelay(1)
if delay1 != 100*time.Millisecond {
t.Errorf("Expected 100ms for attempt 1, got %v", delay1)
}
delay2 := executor.calculateDelay(2)
if delay2 != 200*time.Millisecond {
t.Errorf("Expected 200ms for attempt 2, got %v", delay2)
}
delay3 := executor.calculateDelay(3)
if delay3 != 400*time.Millisecond {
t.Errorf("Expected 400ms for attempt 3, got %v", delay3)
}
// Test max delay cap
delay10 := executor.calculateDelay(10)
if delay10 > 1*time.Second {
t.Errorf("Expected delay capped at 1s, got %v", delay10)
}
}
func TestRetryExecutor_calculateDelay_WithJitter(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
InitialDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
Multiplier: 2.0,
RandomizationFactor: 0.5, // 50% jitter
}
executor := NewRetryExecutor(config, logger)
// With jitter, delay should be within range
baseDelay := 100 * time.Millisecond
minExpected := time.Duration(float64(baseDelay) * 0.5)
maxExpected := time.Duration(float64(baseDelay) * 1.5)
for i := 0; i < 10; i++ {
delay := executor.calculateDelay(1)
if delay < minExpected || delay > maxExpected {
t.Errorf("Delay %v outside expected range [%v, %v]", delay, minExpected, maxExpected)
}
}
}
func TestRetryExecutor_Reset(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
// Generate some metrics
executor.ExecuteWithContext(context.Background(), func() error { return nil })
executor.ExecuteWithContext(context.Background(), func() error { return nil })
// Reset
executor.Reset()
if atomic.LoadInt64(&executor.totalRetries) != 0 {
t.Error("Expected totalRetries to be 0 after reset")
}
if atomic.LoadInt64(&executor.maxRetriesHit) != 0 {
t.Error("Expected maxRetriesHit to be 0 after reset")
}
if atomic.LoadInt64(&executor.totalRequests) != 0 {
t.Error("Expected totalRequests to be 0 after reset")
}
}
func TestRetryExecutor_IsAvailable(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
if !executor.IsAvailable() {
t.Error("Expected IsAvailable to return true")
}
}
func TestRetryExecutor_GetMetrics(t *testing.T) {
logger := &mockLogger{}
config := DefaultRetryConfig()
executor := NewRetryExecutor(config, logger)
// Generate some metrics
executor.ExecuteWithContext(context.Background(), func() error { return nil })
metrics := executor.GetMetrics()
// Check required fields
if _, ok := metrics["totalRetries"]; !ok {
t.Error("Expected 'totalRetries' in metrics")
}
if _, ok := metrics["maxRetriesHit"]; !ok {
t.Error("Expected 'maxRetriesHit' in metrics")
}
if _, ok := metrics["config"]; !ok {
t.Error("Expected 'config' in metrics")
}
if _, ok := metrics["lastRetryTime"]; !ok {
t.Error("Expected 'lastRetryTime' in metrics")
}
}
func TestRetryExecutor_GetMetrics_WithRetries(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 10 * time.Millisecond,
Multiplier: 2.0,
RetryableErrors: []string{"retry"},
}
executor := NewRetryExecutor(config, logger)
// Generate retries
callCount := 0
executor.ExecuteWithContext(context.Background(), func() error {
callCount++
if callCount < 2 {
return errors.New("retry me")
}
return nil
})
metrics := executor.GetMetrics()
totalRetries := metrics["totalRetries"].(int64)
if totalRetries < 1 {
t.Errorf("Expected at least 1 retry, got %d", totalRetries)
}
// Check for average retries calculation
if _, ok := metrics["averageRetriesPerRequest"]; !ok {
t.Error("Expected 'averageRetriesPerRequest' in metrics")
}
}
// =============================================================================
// RECOVERY METRICS TESTS
// =============================================================================
func TestNewRecoveryMetrics(t *testing.T) {
rm := NewRecoveryMetrics()
if rm == nil {
t.Fatal("Expected NewRecoveryMetrics to return non-nil")
}
if rm.mechanisms == nil {
t.Error("Expected mechanisms map to be initialized")
}
}
func TestRecoveryMetrics_RegisterMechanism(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
rm.mu.RLock()
defer rm.mu.RUnlock()
if _, exists := rm.mechanisms["circuit_breaker"]; !exists {
t.Error("Expected mechanism to be registered")
}
}
func TestRecoveryMetrics_UnregisterMechanism(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
rm.UnregisterMechanism("circuit_breaker")
rm.mu.RLock()
defer rm.mu.RUnlock()
if _, exists := rm.mechanisms["circuit_breaker"]; exists {
t.Error("Expected mechanism to be unregistered")
}
}
func TestRecoveryMetrics_GetAllMetrics(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
re := NewRetryExecutor(DefaultRetryConfig(), logger)
rm.RegisterMechanism("retry_executor", re)
metrics := rm.GetAllMetrics()
if _, ok := metrics["circuit_breaker"]; !ok {
t.Error("Expected 'circuit_breaker' in metrics")
}
if _, ok := metrics["retry_executor"]; !ok {
t.Error("Expected 'retry_executor' in metrics")
}
if _, ok := metrics["summary"]; !ok {
t.Error("Expected 'summary' in metrics")
}
summary := metrics["summary"].(map[string]interface{})
if summary["totalMechanisms"] != 2 {
t.Errorf("Expected 2 mechanisms, got %v", summary["totalMechanisms"])
}
}
func TestRecoveryMetrics_GetAllMetrics_WithActivity(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
// Generate some activity
cb.Execute(func() error { return nil })
cb.Execute(func() error { return nil })
metrics := rm.GetAllMetrics()
summary := metrics["summary"].(map[string]interface{})
// Should have success rate calculated
if _, ok := summary["overallSuccessRate"]; !ok {
t.Error("Expected 'overallSuccessRate' in summary")
}
}
func TestRecoveryMetrics_GetMechanismMetrics(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
// Test existing mechanism
metrics, ok := rm.GetMechanismMetrics("circuit_breaker")
if !ok {
t.Error("Expected to find circuit_breaker mechanism")
}
if metrics == nil {
t.Error("Expected metrics to be non-nil")
}
// Test non-existing mechanism
_, ok = rm.GetMechanismMetrics("non_existent")
if ok {
t.Error("Expected to not find non_existent mechanism")
}
}
func TestRecoveryMetrics_HealthCheck(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
// Test with healthy mechanism
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
health := rm.HealthCheck()
if health["status"] != "healthy" {
t.Errorf("Expected status 'healthy', got %v", health["status"])
}
mechanisms := health["mechanisms"].(map[string]interface{})
if mechanisms["circuit_breaker"] != "healthy" {
t.Errorf("Expected circuit_breaker to be 'healthy', got %v", mechanisms["circuit_breaker"])
}
if health["healthy"] != 1 {
t.Errorf("Expected 1 healthy, got %v", health["healthy"])
}
if health["unhealthy"] != 0 {
t.Errorf("Expected 0 unhealthy, got %v", health["unhealthy"])
}
}
func TestRecoveryMetrics_HealthCheck_Degraded(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
// Add a healthy mechanism
cb1 := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("healthy_cb", cb1)
// Add an unhealthy mechanism (trip the circuit breaker)
config := CircuitBreakerConfig{
FailureThreshold: 1,
SuccessThreshold: 10,
Timeout: 1 * time.Hour,
MaxRequests: 1,
}
cb2 := NewCircuitBreaker(config, logger)
cb2.Execute(func() error { return errors.New("fail") })
rm.RegisterMechanism("unhealthy_cb", cb2)
health := rm.HealthCheck()
if health["status"] != "degraded" {
t.Errorf("Expected status 'degraded', got %v", health["status"])
}
}
func TestRecoveryMetrics_HealthCheck_Unhealthy(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
// Add only an unhealthy mechanism
config := CircuitBreakerConfig{
FailureThreshold: 1,
SuccessThreshold: 10,
Timeout: 1 * time.Hour,
MaxRequests: 1,
}
cb := NewCircuitBreaker(config, logger)
cb.Execute(func() error { return errors.New("fail") })
rm.RegisterMechanism("unhealthy_cb", cb)
health := rm.HealthCheck()
if health["status"] != "unhealthy" {
t.Errorf("Expected status 'unhealthy', got %v", health["status"])
}
}
func TestRecoveryMetrics_HTTPMetricsHandler(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism("circuit_breaker", cb)
handler := rm.HTTPMetricsHandler()
req := httptest.NewRequest("GET", "/metrics", nil)
w := httptest.NewRecorder()
handler(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
contentType := w.Header().Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type 'application/json', got %s", contentType)
}
body := w.Body.String()
if body == "" {
t.Error("Expected non-empty response body")
}
}
// =============================================================================
// CONCURRENT ACCESS TESTS
// =============================================================================
func TestRecoveryMetrics_ConcurrentAccess(t *testing.T) {
rm := NewRecoveryMetrics()
logger := &mockLogger{}
var wg sync.WaitGroup
// Concurrent registrations
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
rm.RegisterMechanism(string(rune('a'+idx)), cb)
}(i)
}
// Concurrent reads
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = rm.GetAllMetrics()
_ = rm.HealthCheck()
}()
}
wg.Wait()
// Verify no race conditions
health := rm.HealthCheck()
if health == nil {
t.Error("Expected HealthCheck to return non-nil after concurrent access")
}
}
func TestRetryExecutor_ConcurrentExecution(t *testing.T) {
logger := &mockLogger{}
config := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 10 * time.Millisecond,
Multiplier: 2.0,
RetryableErrors: []string{"retry"},
}
executor := NewRetryExecutor(config, logger)
var wg sync.WaitGroup
successCount := int64(0)
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := executor.ExecuteWithContext(context.Background(), func() error {
return nil
})
if err == nil {
atomic.AddInt64(&successCount, 1)
}
}()
}
wg.Wait()
if successCount != 100 {
t.Errorf("Expected 100 successes, got %d", successCount)
}
}
-403
View File
@@ -1,403 +0,0 @@
// Package security provides security-related middleware and utilities
package security
import (
"net/http"
"strings"
"time"
)
// SecurityHeadersConfig configures security headers
type SecurityHeadersConfig struct {
// Content Security Policy
ContentSecurityPolicy string
// HSTS settings
StrictTransportSecurity string
StrictTransportSecurityMaxAge int // seconds
StrictTransportSecuritySubdomains bool
StrictTransportSecurityPreload bool
// Frame options
FrameOptions string // DENY, SAMEORIGIN, or ALLOW-FROM uri
// Content type options
ContentTypeOptions string // nosniff
// XSS protection
XSSProtection string // 1; mode=block
// Referrer policy
ReferrerPolicy string
// Permissions policy
PermissionsPolicy string
// Cross-origin settings
CrossOriginEmbedderPolicy string
CrossOriginOpenerPolicy string
CrossOriginResourcePolicy string
// CORS settings
CORSEnabled bool
CORSAllowedOrigins []string
CORSAllowedMethods []string
CORSAllowedHeaders []string
CORSAllowCredentials bool
CORSMaxAge int // seconds
// Custom headers
CustomHeaders map[string]string
// Security features
DisableServerHeader bool
DisablePoweredByHeader bool
// Development mode (less strict for local development)
DevelopmentMode bool
}
// DefaultSecurityConfig returns a secure default configuration
func DefaultSecurityConfig() *SecurityHeadersConfig {
return &SecurityHeadersConfig{
ContentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';",
StrictTransportSecurityMaxAge: 31536000, // 1 year
StrictTransportSecuritySubdomains: true,
StrictTransportSecurityPreload: true,
FrameOptions: "DENY",
ContentTypeOptions: "nosniff",
XSSProtection: "1; mode=block",
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()",
CrossOriginEmbedderPolicy: "require-corp",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
CORSEnabled: false,
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
CORSAllowedHeaders: []string{"Authorization", "Content-Type", "X-Requested-With"},
CORSMaxAge: 86400, // 24 hours
DisableServerHeader: true,
DisablePoweredByHeader: true,
DevelopmentMode: false,
}
}
// DevelopmentSecurityConfig returns a configuration suitable for development
func DevelopmentSecurityConfig() *SecurityHeadersConfig {
config := DefaultSecurityConfig()
// Relax CSP for development
config.ContentSecurityPolicy = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
// Allow framing for development tools
config.FrameOptions = "SAMEORIGIN"
// Enable CORS for local development
config.CORSEnabled = true
config.CORSAllowedOrigins = []string{"http://localhost:*", "http://127.0.0.1:*"}
config.CORSAllowCredentials = true
// Relax cross-origin policies
config.CrossOriginEmbedderPolicy = ""
config.CrossOriginOpenerPolicy = "unsafe-none"
config.CrossOriginResourcePolicy = "cross-origin"
config.DevelopmentMode = true
return config
}
// SecurityHeadersMiddleware applies security headers to HTTP responses
type SecurityHeadersMiddleware struct {
config *SecurityHeadersConfig
}
// NewSecurityHeadersMiddleware creates a new security headers middleware
func NewSecurityHeadersMiddleware(config *SecurityHeadersConfig) *SecurityHeadersMiddleware {
if config == nil {
config = DefaultSecurityConfig()
}
return &SecurityHeadersMiddleware{
config: config,
}
}
// Apply applies security headers to the response
func (m *SecurityHeadersMiddleware) Apply(rw http.ResponseWriter, req *http.Request) {
headers := rw.Header()
// Content Security Policy
if m.config.ContentSecurityPolicy != "" {
headers.Set("Content-Security-Policy", m.config.ContentSecurityPolicy)
}
// HSTS (only for HTTPS)
if req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https" {
hstsValue := m.buildHSTSHeader()
if hstsValue != "" {
headers.Set("Strict-Transport-Security", hstsValue)
}
}
// Frame options
if m.config.FrameOptions != "" {
headers.Set("X-Frame-Options", m.config.FrameOptions)
}
// Content type options
if m.config.ContentTypeOptions != "" {
headers.Set("X-Content-Type-Options", m.config.ContentTypeOptions)
}
// XSS protection
if m.config.XSSProtection != "" {
headers.Set("X-XSS-Protection", m.config.XSSProtection)
}
// Referrer policy
if m.config.ReferrerPolicy != "" {
headers.Set("Referrer-Policy", m.config.ReferrerPolicy)
}
// Permissions policy
if m.config.PermissionsPolicy != "" {
headers.Set("Permissions-Policy", m.config.PermissionsPolicy)
}
// Cross-origin policies
if m.config.CrossOriginEmbedderPolicy != "" {
headers.Set("Cross-Origin-Embedder-Policy", m.config.CrossOriginEmbedderPolicy)
}
if m.config.CrossOriginOpenerPolicy != "" {
headers.Set("Cross-Origin-Opener-Policy", m.config.CrossOriginOpenerPolicy)
}
if m.config.CrossOriginResourcePolicy != "" {
headers.Set("Cross-Origin-Resource-Policy", m.config.CrossOriginResourcePolicy)
}
// CORS headers
if m.config.CORSEnabled {
m.applyCORSHeaders(rw, req)
}
// Custom headers
for name, value := range m.config.CustomHeaders {
headers.Set(name, value)
}
// Remove server identification headers
if m.config.DisableServerHeader {
headers.Del("Server")
}
if m.config.DisablePoweredByHeader {
headers.Del("X-Powered-By")
}
// Add security timestamp for debugging
if m.config.DevelopmentMode {
headers.Set("X-Security-Headers-Applied", time.Now().UTC().Format(time.RFC3339))
}
}
// buildHSTSHeader constructs the HSTS header value
func (m *SecurityHeadersMiddleware) buildHSTSHeader() string {
if m.config.StrictTransportSecurityMaxAge <= 0 {
return ""
}
parts := []string{
"max-age=" + string(rune(m.config.StrictTransportSecurityMaxAge)),
}
if m.config.StrictTransportSecuritySubdomains {
parts = append(parts, "includeSubDomains")
}
if m.config.StrictTransportSecurityPreload {
parts = append(parts, "preload")
}
return strings.Join(parts, "; ")
}
// applyCORSHeaders applies CORS headers based on the request
func (m *SecurityHeadersMiddleware) applyCORSHeaders(rw http.ResponseWriter, req *http.Request) {
headers := rw.Header()
origin := req.Header.Get("Origin")
// Check if origin is allowed
if origin != "" && m.isOriginAllowed(origin) {
headers.Set("Access-Control-Allow-Origin", origin)
} else if len(m.config.CORSAllowedOrigins) == 1 && m.config.CORSAllowedOrigins[0] == "*" {
headers.Set("Access-Control-Allow-Origin", "*")
}
// Set other CORS headers
if len(m.config.CORSAllowedMethods) > 0 {
headers.Set("Access-Control-Allow-Methods", strings.Join(m.config.CORSAllowedMethods, ", "))
}
if len(m.config.CORSAllowedHeaders) > 0 {
headers.Set("Access-Control-Allow-Headers", strings.Join(m.config.CORSAllowedHeaders, ", "))
}
if m.config.CORSAllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if m.config.CORSMaxAge > 0 {
headers.Set("Access-Control-Max-Age", string(rune(m.config.CORSMaxAge)))
}
// Handle preflight requests
if req.Method == "OPTIONS" {
headers.Set("Access-Control-Allow-Methods", strings.Join(m.config.CORSAllowedMethods, ", "))
headers.Set("Access-Control-Allow-Headers", strings.Join(m.config.CORSAllowedHeaders, ", "))
rw.WriteHeader(http.StatusOK)
}
}
// isOriginAllowed checks if the origin is in the allowed list
func (m *SecurityHeadersMiddleware) isOriginAllowed(origin string) bool {
for _, allowed := range m.config.CORSAllowedOrigins {
if m.matchOrigin(origin, allowed) {
return true
}
}
return false
}
// matchOrigin checks if an origin matches an allowed pattern
func (m *SecurityHeadersMiddleware) matchOrigin(origin, pattern string) bool {
// Exact match
if origin == pattern {
return true
}
// Wildcard subdomain match (e.g., "https://*.example.com")
if strings.Contains(pattern, "*") {
// Simple wildcard matching for subdomains
if strings.HasPrefix(pattern, "https://*.") {
domain := strings.TrimPrefix(pattern, "https://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
return true
}
}
if strings.HasPrefix(pattern, "http://*.") {
domain := strings.TrimPrefix(pattern, "http://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
return true
}
}
}
// Port wildcard match (e.g., "http://localhost:*")
if strings.HasSuffix(pattern, ":*") {
prefix := strings.TrimSuffix(pattern, ":*")
if strings.HasPrefix(origin, prefix+":") {
return true
}
}
return false
}
// Wrap wraps an HTTP handler with security headers
func (m *SecurityHeadersMiddleware) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
m.Apply(rw, req)
next.ServeHTTP(rw, req)
})
}
// SecurityHeadersHandler is a convenience function that creates and applies security headers
func SecurityHeadersHandler(config *SecurityHeadersConfig) func(http.ResponseWriter, *http.Request) {
middleware := NewSecurityHeadersMiddleware(config)
return middleware.Apply
}
// Common security header presets
// StrictSecurityConfig returns a very strict security configuration
func StrictSecurityConfig() *SecurityHeadersConfig {
config := DefaultSecurityConfig()
// Very strict CSP
config.ContentSecurityPolicy = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
// Stricter frame options
config.FrameOptions = "DENY"
// Disable CORS entirely
config.CORSEnabled = false
// Very strict cross-origin policies
config.CrossOriginEmbedderPolicy = "require-corp"
config.CrossOriginOpenerPolicy = "same-origin"
config.CrossOriginResourcePolicy = "same-site"
return config
}
// APISecurityConfig returns a configuration suitable for APIs
func APISecurityConfig() *SecurityHeadersConfig {
config := DefaultSecurityConfig()
// API-friendly CSP
config.ContentSecurityPolicy = "default-src 'none'; frame-ancestors 'none';"
// Enable CORS for APIs
config.CORSEnabled = true
config.CORSAllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"}
config.CORSAllowedHeaders = []string{"Authorization", "Content-Type", "X-Requested-With", "X-API-Key"}
// API-appropriate policies
config.CrossOriginResourcePolicy = "cross-origin"
return config
}
// ValidateConfig validates the security configuration
func (c *SecurityHeadersConfig) Validate() error {
// Validate HSTS max age
if c.StrictTransportSecurityMaxAge < 0 {
c.StrictTransportSecurityMaxAge = 0
}
// Validate CORS max age
if c.CORSMaxAge < 0 {
c.CORSMaxAge = 0
}
// Validate frame options
validFrameOptions := []string{"DENY", "SAMEORIGIN", ""}
isValidFrameOption := false
for _, valid := range validFrameOptions {
if c.FrameOptions == valid || strings.HasPrefix(c.FrameOptions, "ALLOW-FROM ") {
isValidFrameOption = true
break
}
}
if !isValidFrameOption {
c.FrameOptions = "DENY"
}
return nil
}
// ApplyToResponseWriter is a helper function to quickly apply security headers
func ApplySecurityHeaders(rw http.ResponseWriter, req *http.Request, config *SecurityHeadersConfig) {
middleware := NewSecurityHeadersMiddleware(config)
middleware.Apply(rw, req)
}
-350
View File
@@ -1,350 +0,0 @@
package security
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestDefaultSecurityConfig(t *testing.T) {
config := DefaultSecurityConfig()
if config.ContentSecurityPolicy == "" {
t.Error("Expected default CSP to be set")
}
if config.FrameOptions != "DENY" {
t.Errorf("Expected frame options to be DENY, got %s", config.FrameOptions)
}
if !config.DisableServerHeader {
t.Error("Expected server header to be disabled by default")
}
}
func TestSecurityHeadersMiddleware_Apply(t *testing.T) {
config := DefaultSecurityConfig()
middleware := NewSecurityHeadersMiddleware(config)
// Create a mock request (HTTPS)
req := httptest.NewRequest("GET", "https://example.com/test", nil)
req.TLS = &tls.ConnectionState{} // Mock TLS
// Create a response recorder
rr := httptest.NewRecorder()
// Apply security headers
middleware.Apply(rr, req)
headers := rr.Header()
// Check that security headers are set
if headers.Get("Content-Security-Policy") == "" {
t.Error("Expected CSP header to be set")
}
if headers.Get("X-Frame-Options") != "DENY" {
t.Errorf("Expected X-Frame-Options to be DENY, got %s", headers.Get("X-Frame-Options"))
}
if headers.Get("X-Content-Type-Options") != "nosniff" {
t.Errorf("Expected X-Content-Type-Options to be nosniff, got %s", headers.Get("X-Content-Type-Options"))
}
if headers.Get("X-XSS-Protection") != "1; mode=block" {
t.Errorf("Expected X-XSS-Protection to be '1; mode=block', got %s", headers.Get("X-XSS-Protection"))
}
// Check HSTS for HTTPS requests
hsts := headers.Get("Strict-Transport-Security")
if hsts == "" {
t.Error("Expected HSTS header for HTTPS request")
}
if !strings.Contains(hsts, "max-age=") {
t.Error("Expected HSTS header to contain max-age")
}
}
func TestSecurityHeadersMiddleware_HTTPSOnly(t *testing.T) {
config := DefaultSecurityConfig()
middleware := NewSecurityHeadersMiddleware(config)
// Test HTTP request (no HSTS)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rr := httptest.NewRecorder()
middleware.Apply(rr, req)
if rr.Header().Get("Strict-Transport-Security") != "" {
t.Error("Expected no HSTS header for HTTP request")
}
// Test HTTPS request (with HSTS)
req = httptest.NewRequest("GET", "https://example.com/test", nil)
req.TLS = &tls.ConnectionState{}
rr = httptest.NewRecorder()
middleware.Apply(rr, req)
if rr.Header().Get("Strict-Transport-Security") == "" {
t.Error("Expected HSTS header for HTTPS request")
}
}
func TestCORSHeaders(t *testing.T) {
config := DefaultSecurityConfig()
config.CORSEnabled = true
config.CORSAllowedOrigins = []string{"https://example.com", "https://*.test.com"}
config.CORSAllowCredentials = true
middleware := NewSecurityHeadersMiddleware(config)
tests := []struct {
name string
origin string
expectedOrigin string
}{
{
name: "exact match",
origin: "https://example.com",
expectedOrigin: "https://example.com",
},
{
name: "wildcard subdomain match",
origin: "https://api.test.com",
expectedOrigin: "https://api.test.com",
},
{
name: "no match",
origin: "https://malicious.com",
expectedOrigin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "https://example.com/test", nil)
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
rr := httptest.NewRecorder()
middleware.Apply(rr, req)
actualOrigin := rr.Header().Get("Access-Control-Allow-Origin")
if actualOrigin != tt.expectedOrigin {
t.Errorf("Expected origin %s, got %s", tt.expectedOrigin, actualOrigin)
}
if tt.expectedOrigin != "" {
// Should have credentials header
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Error("Expected credentials header for allowed origin")
}
}
})
}
}
func TestCORSPreflight(t *testing.T) {
config := DefaultSecurityConfig()
config.CORSEnabled = true
config.CORSAllowedOrigins = []string{"*"}
config.CORSAllowedMethods = []string{"GET", "POST", "OPTIONS"}
middleware := NewSecurityHeadersMiddleware(config)
req := httptest.NewRequest("OPTIONS", "https://example.com/test", nil)
req.Header.Set("Origin", "https://other.com")
rr := httptest.NewRecorder()
middleware.Apply(rr, req)
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Error("Expected wildcard origin for preflight request")
}
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Error("Expected methods header for preflight request")
}
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200 for preflight, got %d", rr.Code)
}
}
func TestOriginMatching(t *testing.T) {
config := &SecurityHeadersConfig{
CORSEnabled: true,
CORSAllowedOrigins: []string{
"https://example.com",
"https://*.example.com",
"http://localhost:*",
},
}
middleware := NewSecurityHeadersMiddleware(config)
tests := []struct {
origin string
expected bool
}{
{"https://example.com", true},
{"https://api.example.com", true},
{"https://sub.api.example.com", true},
{"http://localhost:3000", true},
{"http://localhost:8080", true},
{"https://malicious.com", false},
{"http://example.com", false}, // Different scheme
{"https://example.com.evil.com", false}, // Domain suffix attack
}
for _, tt := range tests {
t.Run(tt.origin, func(t *testing.T) {
result := middleware.isOriginAllowed(tt.origin)
if result != tt.expected {
t.Errorf("Origin %s: expected %v, got %v", tt.origin, tt.expected, result)
}
})
}
}
func TestDevelopmentMode(t *testing.T) {
config := DevelopmentSecurityConfig()
if !config.DevelopmentMode {
t.Error("Expected development mode to be enabled")
}
if !config.CORSEnabled {
t.Error("Expected CORS to be enabled in development mode")
}
if config.FrameOptions != "SAMEORIGIN" {
t.Errorf("Expected frame options to be SAMEORIGIN in dev mode, got %s", config.FrameOptions)
}
// Should be less strict CSP
if strings.Contains(config.ContentSecurityPolicy, "'none'") {
t.Error("Expected less strict CSP in development mode")
}
}
func TestStrictSecurityConfig(t *testing.T) {
config := StrictSecurityConfig()
if !strings.Contains(config.ContentSecurityPolicy, "'none'") {
t.Error("Expected very strict CSP with 'none' defaults")
}
if config.CORSEnabled {
t.Error("Expected CORS to be disabled in strict mode")
}
if config.FrameOptions != "DENY" {
t.Error("Expected frame options to be DENY in strict mode")
}
}
func TestAPISecurityConfig(t *testing.T) {
config := APISecurityConfig()
if !config.CORSEnabled {
t.Error("Expected CORS to be enabled for API config")
}
methods := config.CORSAllowedMethods
expectedMethods := []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"}
for _, method := range expectedMethods {
found := false
for _, allowed := range methods {
if allowed == method {
found = true
break
}
}
if !found {
t.Errorf("Expected method %s to be allowed in API config", method)
}
}
}
func TestMiddlewareWrap(t *testing.T) {
config := DefaultSecurityConfig()
middleware := NewSecurityHeadersMiddleware(config)
// Create a simple handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// Wrap with security middleware
wrappedHandler := middleware.Wrap(handler)
req := httptest.NewRequest("GET", "https://example.com/test", nil)
req.TLS = &tls.ConnectionState{}
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
// Check response
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
if rr.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got %s", rr.Body.String())
}
// Check security headers were applied
if rr.Header().Get("X-Frame-Options") == "" {
t.Error("Expected security headers to be applied by wrapper")
}
}
func TestConfigValidation(t *testing.T) {
config := &SecurityHeadersConfig{
StrictTransportSecurityMaxAge: -1,
CORSMaxAge: -1,
FrameOptions: "INVALID",
}
err := config.Validate()
if err != nil {
t.Errorf("Unexpected validation error: %v", err)
}
// Should fix invalid values
if config.StrictTransportSecurityMaxAge != 0 {
t.Error("Expected negative HSTS max age to be reset to 0")
}
if config.CORSMaxAge != 0 {
t.Error("Expected negative CORS max age to be reset to 0")
}
if config.FrameOptions != "DENY" {
t.Error("Expected invalid frame options to be reset to DENY")
}
}
func BenchmarkSecurityHeadersApply(b *testing.B) {
config := DefaultSecurityConfig()
middleware := NewSecurityHeadersMiddleware(config)
req := httptest.NewRequest("GET", "https://example.com/test", nil)
req.TLS = &tls.ConnectionState{}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
rr := httptest.NewRecorder()
middleware.Apply(rr, req)
}
})
}
-394
View File
@@ -1,394 +0,0 @@
// Package singleton provides a centralized, thread-safe singleton management system
// that consolidates all singleton patterns used throughout the application.
// It ensures proper initialization, lifecycle management, and graceful shutdown.
package singleton
import (
"context"
"fmt"
"sync"
"sync/atomic"
)
// Registry is the centralized singleton registry that manages all singleton instances
// in the application. It provides thread-safe initialization, access, and cleanup.
type Registry struct {
mu sync.RWMutex
instances map[string]*Instance
groups map[string]*Group
shutdown int32
wg sync.WaitGroup
}
// Instance represents a singleton instance with lifecycle management
type Instance struct {
name string
value interface{}
initializer func() interface{}
finalizer func(interface{})
once sync.Once
refCount int32
}
// Group represents a group of related singletons
type Group struct {
name string
instances map[string]*Instance
mu sync.RWMutex
}
var (
// globalRegistry is the singleton registry instance
globalRegistry *Registry
// registryOnce ensures single initialization
registryOnce sync.Once
)
// Get returns the global singleton registry
func Get() *Registry {
registryOnce.Do(func() {
globalRegistry = &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
})
return globalRegistry
}
// Register registers a new singleton with its initializer and optional finalizer
func (r *Registry) Register(name string, initializer func() interface{}, finalizer func(interface{})) error {
if atomic.LoadInt32(&r.shutdown) == 1 {
return fmt.Errorf("registry is shutting down")
}
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[name]; exists {
return fmt.Errorf("singleton %s already registered", name)
}
r.instances[name] = &Instance{
name: name,
initializer: initializer,
finalizer: finalizer,
}
return nil
}
// GetInstance retrieves or initializes a singleton instance
func (r *Registry) GetInstance(name string) (interface{}, error) {
if atomic.LoadInt32(&r.shutdown) == 1 {
return nil, fmt.Errorf("registry is shutting down")
}
r.mu.RLock()
instance, exists := r.instances[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("singleton %s not registered", name)
}
// Initialize the singleton if needed
instance.once.Do(func() {
if instance.initializer != nil {
instance.value = instance.initializer()
atomic.AddInt32(&instance.refCount, 1)
}
})
return instance.value, nil
}
// MustGet retrieves a singleton instance, panicking if not found
func (r *Registry) MustGet(name string) interface{} {
val, err := r.GetInstance(name)
if err != nil {
panic(fmt.Sprintf("singleton %s: %v", name, err))
}
return val
}
// RegisterGroup creates a new singleton group
func (r *Registry) RegisterGroup(name string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.groups[name]; exists {
return fmt.Errorf("group %s already exists", name)
}
r.groups[name] = &Group{
name: name,
instances: make(map[string]*Instance),
}
return nil
}
// AddToGroup adds a singleton to a group
func (r *Registry) AddToGroup(groupName, singletonName string) error {
r.mu.Lock()
defer r.mu.Unlock()
group, groupExists := r.groups[groupName]
if !groupExists {
return fmt.Errorf("group %s does not exist", groupName)
}
instance, instanceExists := r.instances[singletonName]
if !instanceExists {
return fmt.Errorf("singleton %s not registered", singletonName)
}
group.mu.Lock()
defer group.mu.Unlock()
group.instances[singletonName] = instance
return nil
}
// GetGroup retrieves all singletons in a group
func (r *Registry) GetGroup(name string) (map[string]interface{}, error) {
r.mu.RLock()
group, exists := r.groups[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("group %s does not exist", name)
}
group.mu.RLock()
defer group.mu.RUnlock()
result := make(map[string]interface{})
for name, instance := range group.instances {
if instance.value != nil {
result[name] = instance.value
}
}
return result, nil
}
// AddReference increments the reference count for a singleton
func (r *Registry) AddReference(name string) error {
r.mu.RLock()
instance, exists := r.instances[name]
r.mu.RUnlock()
if !exists {
return fmt.Errorf("singleton %s not registered", name)
}
atomic.AddInt32(&instance.refCount, 1)
return nil
}
// ReleaseReference decrements the reference count for a singleton
func (r *Registry) ReleaseReference(name string) error {
r.mu.RLock()
instance, exists := r.instances[name]
r.mu.RUnlock()
if !exists {
return fmt.Errorf("singleton %s not registered", name)
}
count := atomic.AddInt32(&instance.refCount, -1)
if count == 0 && instance.finalizer != nil && instance.value != nil {
// Run finalizer when last reference is released
go instance.finalizer(instance.value)
}
return nil
}
// GetReferenceCount returns the reference count for a singleton
func (r *Registry) GetReferenceCount(name string) (int32, error) {
r.mu.RLock()
instance, exists := r.instances[name]
r.mu.RUnlock()
if !exists {
return 0, fmt.Errorf("singleton %s not registered", name)
}
return atomic.LoadInt32(&instance.refCount), nil
}
// Shutdown gracefully shuts down all singletons
func (r *Registry) Shutdown(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&r.shutdown, 0, 1) {
return fmt.Errorf("registry already shutting down")
}
r.mu.Lock()
defer r.mu.Unlock()
// Create error channel for collecting shutdown errors
errChan := make(chan error, len(r.instances))
// Run finalizers for all initialized singletons
for name, instance := range r.instances {
if instance.value != nil && instance.finalizer != nil {
r.wg.Add(1)
go func(n string, i *Instance) {
defer r.wg.Done()
// Run finalizer with panic recovery
func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("finalizer for %s panicked: %v", n, r)
}
}()
i.finalizer(i.value)
}()
}(name, instance)
}
}
// Wait for all finalizers to complete or timeout
done := make(chan struct{})
go func() {
r.wg.Wait()
close(done)
}()
select {
case <-done:
// All finalizers completed
case <-ctx.Done():
return fmt.Errorf("shutdown timeout: %w", ctx.Err())
}
// Collect any errors
close(errChan)
var errs []error
for err := range errChan {
if err != nil {
errs = append(errs, err)
}
}
// Clear all instances
r.instances = make(map[string]*Instance)
r.groups = make(map[string]*Group)
if len(errs) > 0 {
return fmt.Errorf("shutdown errors: %v", errs)
}
return nil
}
// Reset resets the registry (mainly for testing)
func (r *Registry) Reset() {
r.mu.Lock()
defer r.mu.Unlock()
r.instances = make(map[string]*Instance)
r.groups = make(map[string]*Group)
atomic.StoreInt32(&r.shutdown, 0)
}
// Stats returns statistics about the registry
type Stats struct {
TotalRegistered int
TotalInitialized int
TotalGroups int
TotalReferences int32
}
// GetStats returns current registry statistics
func (r *Registry) GetStats() Stats {
r.mu.RLock()
defer r.mu.RUnlock()
stats := Stats{
TotalRegistered: len(r.instances),
TotalGroups: len(r.groups),
}
for _, instance := range r.instances {
if instance.value != nil {
stats.TotalInitialized++
}
stats.TotalReferences += atomic.LoadInt32(&instance.refCount)
}
return stats
}
// Builder provides a fluent interface for registering singletons
type Builder struct {
registry *Registry
name string
initializer func() interface{}
finalizer func(interface{})
group string
}
// NewBuilder creates a new singleton builder
func NewBuilder(name string) *Builder {
return &Builder{
registry: Get(),
name: name,
}
}
// WithInitializer sets the initializer function
func (b *Builder) WithInitializer(init func() interface{}) *Builder {
b.initializer = init
return b
}
// WithFinalizer sets the finalizer function
func (b *Builder) WithFinalizer(final func(interface{})) *Builder {
b.finalizer = final
return b
}
// InGroup adds the singleton to a group
func (b *Builder) InGroup(group string) *Builder {
b.group = group
return b
}
// Register registers the singleton with the configured options
func (b *Builder) Register() error {
if err := b.registry.Register(b.name, b.initializer, b.finalizer); err != nil {
return err
}
if b.group != "" {
// Ensure group exists
if err := b.registry.RegisterGroup(b.group); err != nil {
// Group might already exist, which is ok
if !contains(err.Error(), "already exists") {
return err
}
}
return b.registry.AddToGroup(b.group, b.name)
}
return nil
}
// Helper function to check if string contains substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
-970
View File
@@ -1,970 +0,0 @@
package singleton
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestGet_Singleton tests that Get() returns the same instance
func TestGet_Singleton(t *testing.T) {
registry1 := Get()
registry2 := Get()
if registry1 != registry2 {
t.Error("Get() should return the same instance (singleton)")
}
if registry1 == nil {
t.Error("Get() should not return nil")
}
}
// TestRegistry_Register tests singleton registration
func TestRegistry_Register(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
initializer := func() interface{} {
return "test-value"
}
finalizer := func(v interface{}) {
// Mock finalizer
}
// Test successful registration
err := registry.Register("test-singleton", initializer, finalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Verify instance was registered
if len(registry.instances) != 1 {
t.Error("Instance should be registered")
}
instance := registry.instances["test-singleton"]
if instance == nil {
t.Error("Instance should not be nil")
return
}
if instance.name != "test-singleton" {
t.Errorf("Instance name should be 'test-singleton', got '%s'", instance.name)
}
if instance.initializer == nil {
t.Error("Instance should have initializer")
}
if instance.finalizer == nil {
t.Error("Instance should have finalizer")
}
}
// TestRegistry_Register_Duplicate tests duplicate registration
func TestRegistry_Register_Duplicate(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
initializer := func() interface{} {
return "test-value"
}
// Register first time
err := registry.Register("test-singleton", initializer, nil)
if err != nil {
t.Errorf("First registration should succeed, got error: %v", err)
}
// Register again - should fail
err = registry.Register("test-singleton", initializer, nil)
if err == nil {
t.Error("Duplicate registration should fail")
}
if !strings.Contains(err.Error(), "already registered") {
t.Errorf("Error should mention already registered, got: %v", err)
}
}
// TestRegistry_Register_DuringShutdown tests registration during shutdown
func TestRegistry_Register_DuringShutdown(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
shutdown: 1, // Already shutting down
}
initializer := func() interface{} {
return "test-value"
}
err := registry.Register("test-singleton", initializer, nil)
if err == nil {
t.Error("Registration during shutdown should fail")
}
if !strings.Contains(err.Error(), "shutting down") {
t.Errorf("Error should mention shutting down, got: %v", err)
}
}
// TestRegistry_GetInstance tests singleton retrieval and initialization
func TestRegistry_GetInstance(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
callCount := int32(0)
testValue := "test-value"
initializer := func() interface{} {
atomic.AddInt32(&callCount, 1)
return testValue
}
// Register singleton
err := registry.Register("test-singleton", initializer, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// First get - should initialize
value1, err := registry.GetInstance("test-singleton")
if err != nil {
t.Errorf("GetInstance should succeed, got error: %v", err)
}
if value1 != testValue {
t.Errorf("Value should be '%s', got '%v'", testValue, value1)
}
if atomic.LoadInt32(&callCount) != 1 {
t.Errorf("Initializer should be called once, called %d times", callCount)
}
// Second get - should return same instance without calling initializer
value2, err := registry.GetInstance("test-singleton")
if err != nil {
t.Errorf("GetInstance should succeed, got error: %v", err)
}
if value2 != testValue {
t.Errorf("Value should be '%s', got '%v'", testValue, value2)
}
if atomic.LoadInt32(&callCount) != 1 {
t.Errorf("Initializer should still be called only once, called %d times", callCount)
}
}
// TestRegistry_GetInstance_NotRegistered tests getting unregistered singleton
func TestRegistry_GetInstance_NotRegistered(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
value, err := registry.GetInstance("non-existent")
if err == nil {
t.Error("GetInstance of non-existent singleton should fail")
}
if value != nil {
t.Error("Value should be nil for non-existent singleton")
}
if !strings.Contains(err.Error(), "not registered") {
t.Errorf("Error should mention not registered, got: %v", err)
}
}
// TestRegistry_GetInstance_DuringShutdown tests getting instance during shutdown
func TestRegistry_GetInstance_DuringShutdown(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
shutdown: 1, // Already shutting down
}
value, err := registry.GetInstance("test-singleton")
if err == nil {
t.Error("GetInstance during shutdown should fail")
}
if value != nil {
t.Error("Value should be nil during shutdown")
}
if !strings.Contains(err.Error(), "shutting down") {
t.Errorf("Error should mention shutting down, got: %v", err)
}
}
// TestRegistry_MustGet tests MustGet method
func TestRegistry_MustGet(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
testValue := "test-value"
initializer := func() interface{} {
return testValue
}
// Register singleton
err := registry.Register("test-singleton", initializer, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// MustGet should succeed
value := registry.MustGet("test-singleton")
if value != testValue {
t.Errorf("Value should be '%s', got '%v'", testValue, value)
}
// MustGet non-existent should panic
defer func() {
if r := recover(); r == nil {
t.Error("MustGet of non-existent singleton should panic")
}
}()
registry.MustGet("non-existent")
}
// TestRegistry_RegisterGroup tests group registration
func TestRegistry_RegisterGroup(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Test successful group registration
err := registry.RegisterGroup("test-group")
if err != nil {
t.Errorf("RegisterGroup should succeed, got error: %v", err)
}
// Verify group was registered
if len(registry.groups) != 1 {
t.Error("Group should be registered")
}
group := registry.groups["test-group"]
if group == nil {
t.Error("Group should not be nil")
return
}
if group.name != "test-group" {
t.Errorf("Group name should be 'test-group', got '%s'", group.name)
}
// Test duplicate group registration
err = registry.RegisterGroup("test-group")
if err == nil {
t.Error("Duplicate group registration should fail")
}
if !strings.Contains(err.Error(), "already exists") {
t.Errorf("Error should mention already exists, got: %v", err)
}
}
// TestRegistry_AddToGroup tests adding singletons to groups
func TestRegistry_AddToGroup(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register a singleton
initializer := func() interface{} {
return "test-value"
}
err := registry.Register("test-singleton", initializer, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Register a group
err = registry.RegisterGroup("test-group")
if err != nil {
t.Errorf("RegisterGroup should succeed, got error: %v", err)
}
// Add singleton to group
err = registry.AddToGroup("test-group", "test-singleton")
if err != nil {
t.Errorf("AddToGroup should succeed, got error: %v", err)
}
// Verify singleton is in group
group := registry.groups["test-group"]
if len(group.instances) != 1 {
t.Error("Group should contain one instance")
}
if group.instances["test-singleton"] == nil {
t.Error("Singleton should be in group")
}
// Test adding to non-existent group
err = registry.AddToGroup("non-existent-group", "test-singleton")
if err == nil {
t.Error("Adding to non-existent group should fail")
}
// Test adding non-existent singleton to group
err = registry.AddToGroup("test-group", "non-existent-singleton")
if err == nil {
t.Error("Adding non-existent singleton should fail")
}
}
// TestRegistry_GetGroup tests retrieving group instances
func TestRegistry_GetGroup(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register singletons
err := registry.Register("test-singleton-1", func() interface{} {
return "value-1"
}, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
err = registry.Register("test-singleton-2", func() interface{} {
return "value-2"
}, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Register group and add singletons
err = registry.RegisterGroup("test-group")
if err != nil {
t.Errorf("RegisterGroup should succeed, got error: %v", err)
}
err = registry.AddToGroup("test-group", "test-singleton-1")
if err != nil {
t.Errorf("AddToGroup should succeed, got error: %v", err)
}
err = registry.AddToGroup("test-group", "test-singleton-2")
if err != nil {
t.Errorf("AddToGroup should succeed, got error: %v", err)
}
// Initialize singletons
_, _ = registry.GetInstance("test-singleton-1")
_, _ = registry.GetInstance("test-singleton-2")
// Get group
groupInstances, err := registry.GetGroup("test-group")
if err != nil {
t.Errorf("GetGroup should succeed, got error: %v", err)
}
if len(groupInstances) != 2 {
t.Errorf("Group should contain 2 instances, got %d", len(groupInstances))
}
if groupInstances["test-singleton-1"] != "value-1" {
t.Error("Group should contain correct instance values")
}
if groupInstances["test-singleton-2"] != "value-2" {
t.Error("Group should contain correct instance values")
}
// Test getting non-existent group
_, err = registry.GetGroup("non-existent-group")
if err == nil {
t.Error("Getting non-existent group should fail")
}
}
// TestRegistry_ReferenceCountingv tests reference counting
func TestRegistry_ReferenceCountingv(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
finalizerCalled := int32(0)
finalizer := func(v interface{}) {
atomic.AddInt32(&finalizerCalled, 1)
}
// Register singleton
err := registry.Register("test-singleton", func() interface{} {
return "test-value"
}, finalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Initialize singleton (this adds 1 reference)
_, err = registry.GetInstance("test-singleton")
if err != nil {
t.Errorf("GetInstance should succeed, got error: %v", err)
}
// Check initial reference count
count, err := registry.GetReferenceCount("test-singleton")
if err != nil {
t.Errorf("GetReferenceCount should succeed, got error: %v", err)
}
if count != 1 {
t.Errorf("Reference count should be 1, got %d", count)
}
// Add reference
err = registry.AddReference("test-singleton")
if err != nil {
t.Errorf("AddReference should succeed, got error: %v", err)
}
count, _ = registry.GetReferenceCount("test-singleton")
if count != 2 {
t.Errorf("Reference count should be 2, got %d", count)
}
// Release reference
err = registry.ReleaseReference("test-singleton")
if err != nil {
t.Errorf("ReleaseReference should succeed, got error: %v", err)
}
count, _ = registry.GetReferenceCount("test-singleton")
if count != 1 {
t.Errorf("Reference count should be 1, got %d", count)
}
// Release last reference - should trigger finalizer
err = registry.ReleaseReference("test-singleton")
if err != nil {
t.Errorf("ReleaseReference should succeed, got error: %v", err)
}
count, _ = registry.GetReferenceCount("test-singleton")
if count != 0 {
t.Errorf("Reference count should be 0, got %d", count)
}
// Wait for finalizer to run (it runs in goroutine)
time.Sleep(10 * time.Millisecond)
if atomic.LoadInt32(&finalizerCalled) != 1 {
t.Errorf("Finalizer should be called once, called %d times", finalizerCalled)
}
// Test reference operations on non-existent singleton
err = registry.AddReference("non-existent")
if err == nil {
t.Error("AddReference on non-existent singleton should fail")
}
err = registry.ReleaseReference("non-existent")
if err == nil {
t.Error("ReleaseReference on non-existent singleton should fail")
}
_, err = registry.GetReferenceCount("non-existent")
if err == nil {
t.Error("GetReferenceCount on non-existent singleton should fail")
}
}
// TestRegistry_Shutdown tests graceful shutdown
func TestRegistry_Shutdown(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
finalizerCalled := int32(0)
finalizer := func(v interface{}) {
atomic.AddInt32(&finalizerCalled, 1)
}
// Register and initialize singletons
err := registry.Register("test-singleton-1", func() interface{} {
return "value-1"
}, finalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
err = registry.Register("test-singleton-2", func() interface{} {
return "value-2"
}, finalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Initialize singletons
_, _ = registry.GetInstance("test-singleton-1")
_, _ = registry.GetInstance("test-singleton-2")
// Shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = registry.Shutdown(ctx)
if err != nil {
t.Errorf("Shutdown should succeed, got error: %v", err)
}
// Verify finalizers were called
if atomic.LoadInt32(&finalizerCalled) != 2 {
t.Errorf("Finalizers should be called 2 times, called %d times", finalizerCalled)
}
// Verify registry is cleared
if len(registry.instances) != 0 {
t.Error("Instances should be cleared after shutdown")
}
if len(registry.groups) != 0 {
t.Error("Groups should be cleared after shutdown")
}
// Verify shutdown flag is set
if atomic.LoadInt32(&registry.shutdown) != 1 {
t.Error("Shutdown flag should be set")
}
// Test double shutdown
err = registry.Shutdown(ctx)
if err == nil {
t.Error("Double shutdown should fail")
}
}
// TestRegistry_Shutdown_Timeout tests shutdown timeout
func TestRegistry_Shutdown_Timeout(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register singleton with slow finalizer
slowFinalizer := func(v interface{}) {
time.Sleep(100 * time.Millisecond)
}
err := registry.Register("slow-singleton", func() interface{} {
return "value"
}, slowFinalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Initialize singleton
_, _ = registry.GetInstance("slow-singleton")
// Shutdown with short timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = registry.Shutdown(ctx)
if err == nil {
t.Error("Shutdown should timeout")
}
if !strings.Contains(err.Error(), "timeout") {
t.Errorf("Error should mention timeout, got: %v", err)
}
}
// TestRegistry_Shutdown_PanicRecovery tests panic recovery during shutdown
func TestRegistry_Shutdown_PanicRecovery(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register singleton with panicking finalizer
panicFinalizer := func(v interface{}) {
panic("finalizer panic")
}
err := registry.Register("panic-singleton", func() interface{} {
return "value"
}, panicFinalizer)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Initialize singleton
_, _ = registry.GetInstance("panic-singleton")
// Shutdown should handle panic
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = registry.Shutdown(ctx)
if err == nil {
t.Error("Shutdown should report finalizer panic")
}
if !strings.Contains(err.Error(), "panicked") {
t.Errorf("Error should mention panic, got: %v", err)
}
}
// TestRegistry_Reset tests registry reset
func TestRegistry_Reset(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
shutdown: 1,
}
// Add some data
registry.instances["test"] = &Instance{}
registry.groups["test"] = &Group{}
// Reset
registry.Reset()
// Verify everything is cleared
if len(registry.instances) != 0 {
t.Error("Instances should be cleared after reset")
}
if len(registry.groups) != 0 {
t.Error("Groups should be cleared after reset")
}
if atomic.LoadInt32(&registry.shutdown) != 0 {
t.Error("Shutdown flag should be cleared after reset")
}
}
// TestRegistry_GetStats tests statistics
func TestRegistry_GetStats(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register singletons
err := registry.Register("test-singleton-1", func() interface{} {
return "value-1"
}, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
err = registry.Register("test-singleton-2", func() interface{} {
return "value-2"
}, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Register group
err = registry.RegisterGroup("test-group")
if err != nil {
t.Errorf("RegisterGroup should succeed, got error: %v", err)
}
// Initialize one singleton
_, _ = registry.GetInstance("test-singleton-1")
// Add reference
_ = registry.AddReference("test-singleton-1")
// Get stats
stats := registry.GetStats()
if stats.TotalRegistered != 2 {
t.Errorf("TotalRegistered should be 2, got %d", stats.TotalRegistered)
}
if stats.TotalInitialized != 1 {
t.Errorf("TotalInitialized should be 1, got %d", stats.TotalInitialized)
}
if stats.TotalGroups != 1 {
t.Errorf("TotalGroups should be 1, got %d", stats.TotalGroups)
}
if stats.TotalReferences != 2 { // 1 from initialization + 1 from AddReference
t.Errorf("TotalReferences should be 2, got %d", stats.TotalReferences)
}
}
// TestBuilder tests the fluent builder interface
func TestBuilder(t *testing.T) {
// Reset global registry for clean test
Get().Reset()
testValue := "builder-test-value"
initializer := func() interface{} {
return testValue
}
finalizer := func(v interface{}) {
// Mock finalizer for builder test
}
// Test builder
err := NewBuilder("builder-singleton").
WithInitializer(initializer).
WithFinalizer(finalizer).
InGroup("builder-group").
Register()
if err != nil {
t.Errorf("Builder registration should succeed, got error: %v", err)
}
// Verify singleton was registered
value, err := Get().GetInstance("builder-singleton")
if err != nil {
t.Errorf("GetInstance should succeed, got error: %v", err)
}
if value != testValue {
t.Errorf("Value should be '%s', got '%v'", testValue, value)
}
// Verify group was created and singleton added
groupInstances, err := Get().GetGroup("builder-group")
if err != nil {
t.Errorf("GetGroup should succeed, got error: %v", err)
}
if len(groupInstances) != 1 {
t.Errorf("Group should contain 1 instance, got %d", len(groupInstances))
}
if groupInstances["builder-singleton"] != testValue {
t.Error("Group should contain correct instance")
}
}
// TestBuilder_WithoutGroup tests builder without group
func TestBuilder_WithoutGroup(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
builder := &Builder{
registry: registry,
name: "no-group-singleton",
}
err := builder.WithInitializer(func() interface{} {
return "value"
}).Register()
if err != nil {
t.Errorf("Registration without group should succeed, got error: %v", err)
}
// Verify singleton was registered
if len(registry.instances) != 1 {
t.Error("Singleton should be registered")
}
}
// TestContainsHelper tests the helper string contains function
func TestContainsHelper(t *testing.T) {
tests := []struct {
s string
substr string
expect bool
}{
{"hello world", "world", true},
{"hello world", "hello", true},
{"hello world", "lo wo", true},
{"hello world", "xyz", false},
{"hello", "hello world", false},
{"", "test", false},
{"test", "", true},
{"", "", true},
}
for _, test := range tests {
result := contains(test.s, test.substr)
if result != test.expect {
t.Errorf("contains(%q, %q) = %v, want %v", test.s, test.substr, result, test.expect)
}
}
}
// TestRegistry_ConcurrentAccess tests concurrent access to registry
func TestRegistry_ConcurrentAccess(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
callCount := int32(0)
initializer := func() interface{} {
atomic.AddInt32(&callCount, 1)
return "concurrent-value"
}
// Register singleton
err := registry.Register("concurrent-singleton", initializer, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
var wg sync.WaitGroup
numGoroutines := 50
// Concurrent access
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
value, err := registry.GetInstance("concurrent-singleton")
if err != nil {
t.Errorf("GetInstance should succeed, got error: %v", err)
return
}
if value != "concurrent-value" {
t.Errorf("Value should be 'concurrent-value', got '%v'", value)
}
}()
}
wg.Wait()
// Initializer should be called only once despite concurrent access
if atomic.LoadInt32(&callCount) != 1 {
t.Errorf("Initializer should be called only once, called %d times", callCount)
}
}
// TestRegistry_ConcurrentReferenceOperations tests concurrent reference operations
func TestRegistry_ConcurrentReferenceOperations(t *testing.T) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
// Register singleton
err := registry.Register("ref-singleton", func() interface{} {
return "ref-value"
}, nil)
if err != nil {
t.Errorf("Register should succeed, got error: %v", err)
}
// Initialize singleton
_, _ = registry.GetInstance("ref-singleton")
var wg sync.WaitGroup
numGoroutines := 20
// Concurrent reference operations
wg.Add(numGoroutines * 2)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
_ = registry.AddReference("ref-singleton")
}()
go func() {
defer wg.Done()
_ = registry.ReleaseReference("ref-singleton")
}()
}
wg.Wait()
// Reference count should be consistent (initial 1 + net operations)
count, err := registry.GetReferenceCount("ref-singleton")
if err != nil {
t.Errorf("GetReferenceCount should succeed, got error: %v", err)
}
// Count should be >= 0 due to balanced add/release operations
if count < 0 {
t.Errorf("Reference count should not be negative, got %d", count)
}
}
// Benchmark tests for performance verification
func BenchmarkRegistry_GetInstance(b *testing.B) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
registry.Register("benchmark-singleton", func() interface{} {
return "benchmark-value"
}, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
registry.GetInstance("benchmark-singleton")
}
}
func BenchmarkRegistry_ConcurrentGetInstance(b *testing.B) {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
registry.Register("concurrent-benchmark", func() interface{} {
return "concurrent-value"
}, nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
registry.GetInstance("concurrent-benchmark")
}
})
}
func BenchmarkBuilder_Register(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
registry := &Registry{
instances: make(map[string]*Instance),
groups: make(map[string]*Group),
}
builder := &Builder{
registry: registry,
name: fmt.Sprintf("benchmark-%d", i),
}
builder.WithInitializer(func() interface{} {
return "value"
}).Register()
}
}
-393
View File
@@ -1,393 +0,0 @@
// Package testing provides unified mock implementations for tests
package testing
import (
"fmt"
"net/http"
"sync"
"time"
)
// UnifiedMockLogger provides a standard mock logger for all tests
type UnifiedMockLogger struct {
LoggedMessages []string
mu sync.RWMutex
}
func NewUnifiedMockLogger() *UnifiedMockLogger {
return &UnifiedMockLogger{
LoggedMessages: make([]string, 0),
}
}
func (l *UnifiedMockLogger) Debug(msg string) {
l.mu.Lock()
defer l.mu.Unlock()
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("DEBUG: %s", msg))
}
func (l *UnifiedMockLogger) Debugf(format string, args ...interface{}) {
l.Debug(fmt.Sprintf(format, args...))
}
func (l *UnifiedMockLogger) Info(msg string) {
l.mu.Lock()
defer l.mu.Unlock()
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("INFO: %s", msg))
}
func (l *UnifiedMockLogger) Infof(format string, args ...interface{}) {
l.Info(fmt.Sprintf(format, args...))
}
func (l *UnifiedMockLogger) Error(msg string) {
l.mu.Lock()
defer l.mu.Unlock()
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("ERROR: %s", msg))
}
func (l *UnifiedMockLogger) Errorf(format string, args ...interface{}) {
l.Error(fmt.Sprintf(format, args...))
}
func (l *UnifiedMockLogger) GetMessages() []string {
l.mu.RLock()
defer l.mu.RUnlock()
result := make([]string, len(l.LoggedMessages))
copy(result, l.LoggedMessages)
return result
}
func (l *UnifiedMockLogger) Clear() {
l.mu.Lock()
defer l.mu.Unlock()
l.LoggedMessages = l.LoggedMessages[:0]
}
// UnifiedMockSession provides a standard mock session for all tests
type UnifiedMockSession struct {
authenticated bool
idToken string
accessToken string
refreshToken string
email string
csrf string
nonce string
codeVerifier string
incomingPath string
redirectCount int
mu sync.RWMutex
}
func NewUnifiedMockSession() *UnifiedMockSession {
return &UnifiedMockSession{}
}
func (s *UnifiedMockSession) GetAuthenticated() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.authenticated
}
func (s *UnifiedMockSession) SetAuthenticated(auth bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.authenticated = auth
return nil
}
func (s *UnifiedMockSession) GetIDToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.idToken
}
func (s *UnifiedMockSession) SetIDToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.idToken = token
}
func (s *UnifiedMockSession) GetAccessToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.accessToken
}
func (s *UnifiedMockSession) SetAccessToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.accessToken = token
}
func (s *UnifiedMockSession) GetRefreshToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.refreshToken
}
func (s *UnifiedMockSession) SetRefreshToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.refreshToken = token
}
func (s *UnifiedMockSession) GetEmail() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.email
}
func (s *UnifiedMockSession) SetEmail(email string) {
s.mu.Lock()
defer s.mu.Unlock()
s.email = email
}
func (s *UnifiedMockSession) GetCSRF() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.csrf
}
func (s *UnifiedMockSession) SetCSRF(csrf string) {
s.mu.Lock()
defer s.mu.Unlock()
s.csrf = csrf
}
func (s *UnifiedMockSession) GetNonce() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.nonce
}
func (s *UnifiedMockSession) SetNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.nonce = nonce
}
func (s *UnifiedMockSession) GetCodeVerifier() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.codeVerifier
}
func (s *UnifiedMockSession) SetCodeVerifier(verifier string) {
s.mu.Lock()
defer s.mu.Unlock()
s.codeVerifier = verifier
}
func (s *UnifiedMockSession) GetIncomingPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.incomingPath
}
func (s *UnifiedMockSession) SetIncomingPath(path string) {
s.mu.Lock()
defer s.mu.Unlock()
s.incomingPath = path
}
func (s *UnifiedMockSession) GetRedirectCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.redirectCount
}
func (s *UnifiedMockSession) IncrementRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.redirectCount++
}
func (s *UnifiedMockSession) ResetRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.redirectCount = 0
}
func (s *UnifiedMockSession) Save(req *http.Request, rw http.ResponseWriter) error {
return nil
}
func (s *UnifiedMockSession) Clear(req *http.Request, rw http.ResponseWriter) error {
s.mu.Lock()
defer s.mu.Unlock()
s.authenticated = false
s.idToken = ""
s.accessToken = ""
s.refreshToken = ""
s.email = ""
s.csrf = ""
s.nonce = ""
s.codeVerifier = ""
s.incomingPath = ""
s.redirectCount = 0
return nil
}
func (s *UnifiedMockSession) MarkDirty() {}
func (s *UnifiedMockSession) IsDirty() bool {
return false
}
func (s *UnifiedMockSession) ReturnToPoolSafely() {}
// UnifiedMockTokenVerifier provides a standard mock token verifier
type UnifiedMockTokenVerifier struct {
ShouldFail bool
Error error
}
func NewUnifiedMockTokenVerifier() *UnifiedMockTokenVerifier {
return &UnifiedMockTokenVerifier{}
}
func (v *UnifiedMockTokenVerifier) VerifyToken(token string) error {
if v.ShouldFail {
if v.Error != nil {
return v.Error
}
return fmt.Errorf("mock verification failed")
}
return nil
}
// UnifiedMockTokenCache provides a standard mock token cache
type UnifiedMockTokenCache struct {
data map[string]map[string]interface{}
mu sync.RWMutex
}
func NewUnifiedMockTokenCache() *UnifiedMockTokenCache {
return &UnifiedMockTokenCache{
data: make(map[string]map[string]interface{}),
}
}
func (c *UnifiedMockTokenCache) Get(key string) (map[string]interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, exists := c.data[key]
return value, exists
}
func (c *UnifiedMockTokenCache) Set(key string, claims map[string]interface{}, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[key] = claims
}
func (c *UnifiedMockTokenCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.data, key)
}
func (c *UnifiedMockTokenCache) SetMaxSize(size int) {}
func (c *UnifiedMockTokenCache) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.data)
}
func (c *UnifiedMockTokenCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.data = make(map[string]map[string]interface{})
}
func (c *UnifiedMockTokenCache) Cleanup() {}
func (c *UnifiedMockTokenCache) Close() {}
func (c *UnifiedMockTokenCache) GetStats() map[string]interface{} {
return map[string]interface{}{
"size": c.Size(),
}
}
// UnifiedMockHTTPClient provides a mock HTTP client for tests
type UnifiedMockHTTPClient struct {
Responses map[string]*http.Response
Errors map[string]error
mu sync.RWMutex
}
func NewUnifiedMockHTTPClient() *UnifiedMockHTTPClient {
return &UnifiedMockHTTPClient{
Responses: make(map[string]*http.Response),
Errors: make(map[string]error),
}
}
func (c *UnifiedMockHTTPClient) Do(req *http.Request) (*http.Response, error) {
c.mu.RLock()
defer c.mu.RUnlock()
url := req.URL.String()
if err, exists := c.Errors[url]; exists {
return nil, err
}
if resp, exists := c.Responses[url]; exists {
return resp, nil
}
// Default response
return &http.Response{
StatusCode: 200,
Body: http.NoBody,
Header: make(http.Header),
}, nil
}
func (c *UnifiedMockHTTPClient) SetResponse(url string, response *http.Response) {
c.mu.Lock()
defer c.mu.Unlock()
c.Responses[url] = response
}
func (c *UnifiedMockHTTPClient) SetError(url string, err error) {
c.mu.Lock()
defer c.mu.Unlock()
c.Errors[url] = err
}
// TestSuite provides a unified test setup and teardown
type TestSuite struct {
Logger *UnifiedMockLogger
Session *UnifiedMockSession
TokenVerifier *UnifiedMockTokenVerifier
TokenCache *UnifiedMockTokenCache
HTTPClient *UnifiedMockHTTPClient
}
func NewTestSuite() *TestSuite {
return &TestSuite{
Logger: NewUnifiedMockLogger(),
Session: NewUnifiedMockSession(),
TokenVerifier: NewUnifiedMockTokenVerifier(),
TokenCache: NewUnifiedMockTokenCache(),
HTTPClient: NewUnifiedMockHTTPClient(),
}
}
func (ts *TestSuite) Setup() {
// Common test setup
ts.Logger.Clear()
_ = ts.Session.Clear(nil, nil) // Safe to ignore: test helper function
ts.TokenCache.Clear()
ts.TokenVerifier.ShouldFail = false
ts.TokenVerifier.Error = nil
}
func (ts *TestSuite) Teardown() {
// Common test teardown
ts.TokenCache.Close()
}
+140
View File
@@ -0,0 +1,140 @@
package testutil
import (
"time"
"github.com/lukaszraczylo/traefikoidc/internal/testutil/fixtures"
"github.com/lukaszraczylo/traefikoidc/internal/testutil/mocks"
"github.com/lukaszraczylo/traefikoidc/internal/testutil/servers"
)
// Re-export types for easier access from main package tests
type (
// Mocks
JWKCacheMock = mocks.JWKCache
TokenExchangerMock = mocks.TokenExchanger
TokenVerifierMock = mocks.TokenVerifier
JWTVerifierMock = mocks.JWTVerifier
SessionManagerMock = mocks.SessionManager
CacheMock = mocks.Cache
TokenCacheMock = mocks.TokenCache
BlacklistMock = mocks.Blacklist
HTTPClientMock = mocks.HTTPClient
RoundTripperMock = mocks.RoundTripper
LoggerMock = mocks.Logger
// Mock types
JWKSet = mocks.JWKSet
JWK = mocks.JWK
MockTokenResponse = mocks.TokenResponse
MockSessionData = mocks.SessionData
IntrospectionResp = mocks.IntrospectionResponse
// Fixtures
TokenFixture = fixtures.TokenFixture
// Servers
OIDCServer = servers.OIDCServer
OIDCServerConfig = servers.OIDCServerConfig
OIDCError = servers.OIDCError
)
// NewJWKCacheMock creates a new JWK cache mock
func NewJWKCacheMock() *mocks.JWKCache {
return new(mocks.JWKCache)
}
// NewTokenExchangerMock creates a new token exchanger mock
func NewTokenExchangerMock() *mocks.TokenExchanger {
return new(mocks.TokenExchanger)
}
// NewTokenVerifierMock creates a new token verifier mock
func NewTokenVerifierMock() *mocks.TokenVerifier {
return new(mocks.TokenVerifier)
}
// NewJWTVerifierMock creates a new JWT verifier mock
func NewJWTVerifierMock() *mocks.JWTVerifier {
return new(mocks.JWTVerifier)
}
// NewSessionManagerMock creates a new session manager mock
func NewSessionManagerMock() *mocks.SessionManager {
return new(mocks.SessionManager)
}
// NewCacheMock creates a new cache mock
func NewCacheMock() *mocks.Cache {
return new(mocks.Cache)
}
// NewTokenCacheMock creates a new token cache mock
func NewTokenCacheMock() *mocks.TokenCache {
return new(mocks.TokenCache)
}
// NewBlacklistMock creates a new blacklist mock
func NewBlacklistMock() *mocks.Blacklist {
return new(mocks.Blacklist)
}
// NewHTTPClientMock creates a new HTTP client mock
func NewHTTPClientMock() *mocks.HTTPClient {
return new(mocks.HTTPClient)
}
// NewRoundTripperMock creates a new round tripper mock
func NewRoundTripperMock() *mocks.RoundTripper {
return new(mocks.RoundTripper)
}
// NewLoggerMock creates a new logger mock
func NewLoggerMock() *mocks.Logger {
return new(mocks.Logger)
}
// NewTokenFixture creates a new token fixture
func NewTokenFixture() (*fixtures.TokenFixture, error) {
return fixtures.NewTokenFixture()
}
// NewOIDCServer creates a new mock OIDC server
func NewOIDCServer(config *servers.OIDCServerConfig) *servers.OIDCServer {
return servers.NewOIDCServer(config)
}
// DefaultServerConfig returns a default server configuration
func DefaultServerConfig() *servers.OIDCServerConfig {
return servers.DefaultConfig()
}
// GoogleServerConfig returns a Google-like server configuration
func GoogleServerConfig() *servers.OIDCServerConfig {
return servers.GoogleConfig()
}
// AzureServerConfig returns an Azure AD-like server configuration
func AzureServerConfig() *servers.OIDCServerConfig {
return servers.AzureConfig()
}
// Auth0ServerConfig returns an Auth0-like server configuration
func Auth0ServerConfig() *servers.OIDCServerConfig {
return servers.Auth0Config()
}
// KeycloakServerConfig returns a Keycloak-like server configuration
func KeycloakServerConfig() *servers.OIDCServerConfig {
return servers.KeycloakConfig()
}
// SlowServerConfig returns a configuration with delays
func SlowServerConfig(delay time.Duration) *servers.OIDCServerConfig {
return servers.SlowServerConfig(delay)
}
// RateLimitedServerConfig returns a rate-limited configuration
func RateLimitedServerConfig(afterN int) *servers.OIDCServerConfig {
return servers.RateLimitedConfig(afterN)
}
+330
View File
@@ -0,0 +1,330 @@
package fixtures
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"math/big"
"time"
)
// TokenFixture provides JWT token generation for tests
type TokenFixture struct {
RSAPrivateKey *rsa.PrivateKey
RSAPublicKey *rsa.PublicKey
ECPrivateKey *ecdsa.PrivateKey
ECPublicKey *ecdsa.PublicKey
KeyID string
Issuer string
Audience string
ClockSkew time.Duration
}
// NewTokenFixture creates a new token fixture with generated keys
func NewTokenFixture() (*TokenFixture, error) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
return &TokenFixture{
RSAPrivateKey: rsaKey,
RSAPublicKey: &rsaKey.PublicKey,
ECPrivateKey: ecKey,
ECPublicKey: &ecKey.PublicKey,
KeyID: "test-key-id",
Issuer: "https://test-issuer.com",
Audience: "test-client-id",
ClockSkew: 2 * time.Minute,
}, nil
}
// DefaultClaims returns standard JWT claims
func (f *TokenFixture) DefaultClaims() map[string]interface{} {
now := time.Now()
return map[string]interface{}{
"iss": f.Issuer,
"aud": f.Audience,
"sub": "test-subject",
"email": "user@example.com",
"exp": now.Add(1 * time.Hour).Unix(),
"iat": now.Add(-f.ClockSkew).Unix(),
"nbf": now.Add(-f.ClockSkew).Unix(),
"nonce": "test-nonce",
"jti": generateJTI(),
}
}
// ValidToken creates a valid JWT token with optional claim overrides
func (f *TokenFixture) ValidToken(claimOverrides map[string]interface{}) (string, error) {
claims := f.DefaultClaims()
for k, v := range claimOverrides {
claims[k] = v
}
return f.createJWT(claims, "RS256", f.KeyID)
}
// ExpiredToken creates an expired JWT token
func (f *TokenFixture) ExpiredToken() (string, error) {
claims := f.DefaultClaims()
claims["exp"] = time.Now().Add(-1 * time.Hour).Unix()
return f.createJWT(claims, "RS256", f.KeyID)
}
// NotYetValidToken creates a token that's not valid yet
func (f *TokenFixture) NotYetValidToken() (string, error) {
claims := f.DefaultClaims()
claims["nbf"] = time.Now().Add(1 * time.Hour).Unix()
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithSkew creates a token with a specific time offset
func (f *TokenFixture) TokenWithSkew(skew time.Duration) (string, error) {
claims := f.DefaultClaims()
claims["exp"] = time.Now().Add(skew).Unix()
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithRoles creates a token with specific roles
func (f *TokenFixture) TokenWithRoles(roles []string) (string, error) {
claims := f.DefaultClaims()
claims["roles"] = roles
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithGroups creates a token with specific groups
func (f *TokenFixture) TokenWithGroups(groups []string) (string, error) {
claims := f.DefaultClaims()
claims["groups"] = groups
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithEmail creates a token with a specific email
func (f *TokenFixture) TokenWithEmail(email string) (string, error) {
claims := f.DefaultClaims()
claims["email"] = email
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithAudience creates a token with a specific audience
func (f *TokenFixture) TokenWithAudience(audience string) (string, error) {
claims := f.DefaultClaims()
claims["aud"] = audience
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithIssuer creates a token with a specific issuer
func (f *TokenFixture) TokenWithIssuer(issuer string) (string, error) {
claims := f.DefaultClaims()
claims["iss"] = issuer
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenMissingClaim creates a token missing specified claims
func (f *TokenFixture) TokenMissingClaim(missingClaims ...string) (string, error) {
claims := f.DefaultClaims()
for _, claim := range missingClaims {
delete(claims, claim)
}
return f.createJWT(claims, "RS256", f.KeyID)
}
// TokenWithCustomClaims creates a token with custom claims
func (f *TokenFixture) TokenWithCustomClaims(customClaims map[string]interface{}) (string, error) {
claims := f.DefaultClaims()
for k, v := range customClaims {
claims[k] = v
}
return f.createJWT(claims, "RS256", f.KeyID)
}
// MalformedToken returns an invalid JWT string
func (f *TokenFixture) MalformedToken() string {
return "not.a.valid.jwt"
}
// EmptyToken returns an empty string
func (f *TokenFixture) EmptyToken() string {
return ""
}
// TokenWithWrongSignature creates a token signed with a different key
func (f *TokenFixture) TokenWithWrongSignature() (string, error) {
wrongKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", err
}
claims := f.DefaultClaims()
return createJWTWithKey(claims, "RS256", f.KeyID, wrongKey)
}
// TokenWithWrongAlgorithm creates a token with mismatched algorithm
func (f *TokenFixture) TokenWithWrongAlgorithm() (string, error) {
claims := f.DefaultClaims()
// Create token claiming RS256 but we'll return it as-is
// This simulates algorithm confusion attacks
return f.createJWT(claims, "none", f.KeyID)
}
// ECToken creates a token signed with EC key
func (f *TokenFixture) ECToken() (string, error) {
claims := f.DefaultClaims()
return f.createECJWT(claims, "ES256", f.KeyID)
}
// GetJWKS returns a JWKS containing the test public key
func (f *TokenFixture) GetJWKS() map[string]interface{} {
return map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": f.KeyID,
"use": "sig",
"alg": "RS256",
"n": base64.RawURLEncoding.EncodeToString(f.RSAPublicKey.N.Bytes()),
"e": base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(f.RSAPublicKey.E)))),
},
},
}
}
// GetJWKSBytes returns JWKS as JSON bytes
func (f *TokenFixture) GetJWKSBytes() ([]byte, error) {
return json.Marshal(f.GetJWKS())
}
// createJWT creates a JWT with the fixture's RSA key
func (f *TokenFixture) createJWT(claims map[string]interface{}, alg, kid string) (string, error) {
return createJWTWithKey(claims, alg, kid, f.RSAPrivateKey)
}
// createECJWT creates a JWT with the fixture's EC key
func (f *TokenFixture) createECJWT(claims map[string]interface{}, alg, kid string) (string, error) {
return createECJWTWithKey(claims, alg, kid, f.ECPrivateKey)
}
// Helper functions
func generateJTI() string {
b := make([]byte, 16)
_, _ = rand.Read(b) // #nosec G104 - test fixture, crypto strength not critical
return base64.RawURLEncoding.EncodeToString(b)
}
func bigIntToBytes(i *big.Int) []byte {
return i.Bytes()
}
func createJWTWithKey(claims map[string]interface{}, alg, kid string, key *rsa.PrivateKey) (string, error) {
header := map[string]interface{}{
"alg": alg,
"typ": "JWT",
"kid": kid,
}
headerBytes, err := json.Marshal(header)
if err != nil {
return "", err
}
claimsBytes, err := json.Marshal(claims)
if err != nil {
return "", err
}
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsBytes)
signingInput := headerB64 + "." + claimsB64
// For "none" algorithm, return without signature
if alg == "none" {
return signingInput + ".", nil
}
// Sign with RSA-SHA256
signature, err := signRS256([]byte(signingInput), key)
if err != nil {
return "", err
}
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
return signingInput + "." + signatureB64, nil
}
func createECJWTWithKey(claims map[string]interface{}, alg, kid string, key *ecdsa.PrivateKey) (string, error) {
header := map[string]interface{}{
"alg": alg,
"typ": "JWT",
"kid": kid,
}
headerBytes, err := json.Marshal(header)
if err != nil {
return "", err
}
claimsBytes, err := json.Marshal(claims)
if err != nil {
return "", err
}
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsBytes)
signingInput := headerB64 + "." + claimsB64
// Sign with ECDSA-SHA256
signature, err := signES256([]byte(signingInput), key)
if err != nil {
return "", err
}
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
return signingInput + "." + signatureB64, nil
}
func signRS256(data []byte, key *rsa.PrivateKey) ([]byte, error) {
h := hashSHA256(data)
return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h)
}
func signES256(data []byte, key *ecdsa.PrivateKey) ([]byte, error) {
h := hashSHA256(data)
r, s, err := ecdsa.Sign(rand.Reader, key, h)
if err != nil {
return nil, err
}
// Encode r and s as fixed-size byte arrays
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}
rBytes := r.Bytes()
sBytes := s.Bytes()
signature := make([]byte, 2*keyBytes)
copy(signature[keyBytes-len(rBytes):keyBytes], rBytes)
copy(signature[2*keyBytes-len(sBytes):], sBytes)
return signature, nil
}
func hashSHA256(data []byte) []byte {
h := sha256.Sum256(data)
return h[:]
}

Some files were not shown because too many files have changed in this diff Show More