mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c474bbafd6 | |||
| 9126c74723 | |||
| a750c4f5b9 | |||
| 56051779ee |
@@ -18,6 +18,6 @@ jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24"
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
|
||||
@@ -17,5 +17,5 @@ jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24"
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
|
||||
+2
-1
@@ -1,2 +1,3 @@
|
||||
docker/
|
||||
.claude/
|
||||
.claude/*.out
|
||||
*.test
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ builds:
|
||||
|
||||
# Create source archive for GitHub releases
|
||||
archives:
|
||||
- format: tar.gz
|
||||
- formats: [tar.gz]
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
|
||||
files:
|
||||
- "*.go"
|
||||
|
||||
+118
@@ -77,6 +77,7 @@ testData:
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
@@ -120,6 +121,8 @@ testData:
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
@@ -266,6 +269,8 @@ testDataWithRedis:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -287,6 +292,26 @@ testDataWithRedis:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -605,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -903,6 +960,67 @@ configuration:
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
|
||||
-286
@@ -1,286 +0,0 @@
|
||||
# CI/CD Setup Guide
|
||||
|
||||
## 📋 Overview
|
||||
|
||||
This repository now has a comprehensive CI/CD pipeline that runs **20+ parallel checks** on every pull request to ensure code quality, security, and reliability.
|
||||
|
||||
## 🎯 What Was Added
|
||||
|
||||
### GitHub Actions Workflow
|
||||
- **`.github/workflows/pr-validation.yml`** - Main CI/CD pipeline (single file, all parallel)
|
||||
|
||||
### Configuration Files
|
||||
- **`.golangci.yml`** - Linter configuration with 30+ enabled checks
|
||||
- **`.github/dependabot.yml`** - Automated dependency updates
|
||||
- **`.github/CODEOWNERS`** - Automatic PR reviewer assignment
|
||||
- **`.github/PULL_REQUEST_TEMPLATE.md`** - Standardized PR descriptions
|
||||
- **`.github/workflows/README.md`** - Detailed workflow documentation
|
||||
- **`.github/workflows/.gitattributes`** - Consistent line endings
|
||||
|
||||
## ✅ What Gets Tested (All in Parallel)
|
||||
|
||||
### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters including style, complexity, bugs
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning with SARIF reports
|
||||
- **Govulncheck** - Go vulnerability database scanning
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
### Testing (9 test suites)
|
||||
- **Race Detector** - Concurrent access bugs
|
||||
- **Coverage** - 75% threshold with PR comments
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration suite
|
||||
- **Regression Tests** - Prevent old bugs from returning
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management
|
||||
- **Token Tests** - Token validation
|
||||
- **CSRF Tests** - CSRF protection
|
||||
|
||||
### Provider Testing (9 providers in parallel)
|
||||
- Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, GitHub, Generic
|
||||
|
||||
### Performance & Build (3 checks)
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Multi-platform Build** - 4 combinations (linux/darwin × amd64/arm64)
|
||||
- **Go Version Compatibility** - Go 1.23 & 1.24
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Push to GitHub
|
||||
```bash
|
||||
git add .github .golangci.yml CI_SETUP.md
|
||||
git commit -m "Add comprehensive CI/CD pipeline"
|
||||
git push origin main
|
||||
```
|
||||
|
||||
### 2. Create a Test PR
|
||||
```bash
|
||||
# Create a feature branch
|
||||
git checkout -b feature/test-ci
|
||||
echo "# Test" >> test.md
|
||||
git add test.md
|
||||
git commit -m "Test CI pipeline"
|
||||
git push origin feature/test-ci
|
||||
|
||||
# Create PR on GitHub
|
||||
# Watch all 20+ checks run in parallel! ⚡
|
||||
```
|
||||
|
||||
### 3. Monitor Results
|
||||
- Go to Actions tab: `https://github.com/{owner}/{repo}/actions`
|
||||
- Click on latest workflow run
|
||||
- See all parallel checks in action
|
||||
- Review coverage comment on PR
|
||||
|
||||
## 📊 Key Features
|
||||
|
||||
### ⚡ Maximum Speed
|
||||
- **Parallel execution** - All checks run simultaneously
|
||||
- **Smart caching** - Go modules and build cache
|
||||
- **Optimized order** - Quick checks first for fast feedback
|
||||
- **Expected runtime**: 5-10 minutes for full suite
|
||||
|
||||
### 🔒 Security First
|
||||
- **3 security scanners** - gosec, govulncheck, CodeQL
|
||||
- **SARIF integration** - Results in GitHub Security tab
|
||||
- **Dependency scanning** - Automated with Dependabot
|
||||
- **Security edge case tests**
|
||||
|
||||
### 📈 Coverage Tracking
|
||||
- **Automatic PR comments** with coverage stats
|
||||
- **Per-package breakdown** included
|
||||
- **75% threshold** enforced (configurable)
|
||||
- **Codecov integration** ready (optional)
|
||||
|
||||
### 🎨 Developer Experience
|
||||
- **Clear PR template** guides contributors
|
||||
- **Auto code owners** assignment
|
||||
- **Detailed error messages** for failures
|
||||
- **Benchmark tracking** for performance
|
||||
|
||||
## 🛠️ Local Development
|
||||
|
||||
### Install Required Tools
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
### Run Checks Locally
|
||||
```bash
|
||||
# Quick validation (before committing)
|
||||
gofmt -s -w . # Format code
|
||||
go vet ./... # Basic checks
|
||||
go mod tidy # Clean dependencies
|
||||
|
||||
# Linting
|
||||
golangci-lint run # Full lint suite
|
||||
staticcheck ./... # Static analysis
|
||||
|
||||
# Testing
|
||||
go test -race -timeout=15m ./... # Tests with race detector
|
||||
go test -coverprofile=coverage.out ./... # Coverage
|
||||
go tool cover -func=coverage.out # View coverage
|
||||
|
||||
# Security
|
||||
gosec ./... # Security scan
|
||||
govulncheck ./... # Vulnerability check
|
||||
|
||||
# Benchmarks
|
||||
go test -bench=. -benchmem ./... # Performance tests
|
||||
```
|
||||
|
||||
### Pre-commit Checklist
|
||||
```bash
|
||||
# Run this before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "✅ Ready to commit!"
|
||||
```
|
||||
|
||||
## 📝 Configuration
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
THRESHOLD=75 # Change to desired percentage
|
||||
```
|
||||
|
||||
### Modify Linter Rules
|
||||
Edit `.golangci.yml`:
|
||||
```yaml
|
||||
linters:
|
||||
enable:
|
||||
- newlinter # Add new linters here
|
||||
```
|
||||
|
||||
### Update Go Version
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
go-version: '1.24' # Update version
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Coverage Below Threshold
|
||||
```bash
|
||||
# See uncovered lines in browser
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Found
|
||||
```bash
|
||||
# Run specific test with race detector
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
### Linter Errors
|
||||
```bash
|
||||
# See detailed lint errors
|
||||
golangci-lint run -v
|
||||
|
||||
# Auto-fix some issues
|
||||
golangci-lint run --fix
|
||||
```
|
||||
|
||||
### Provider Test Fails
|
||||
```bash
|
||||
# Test specific provider
|
||||
go test -v -run='.*Azure.*' ./internal/providers/
|
||||
```
|
||||
|
||||
## 📈 Metrics & Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
- View all runs: `Actions` tab
|
||||
- Filter by workflow, branch, status
|
||||
- Download logs and artifacts
|
||||
|
||||
### Status Badge
|
||||
Add to README.md:
|
||||
```markdown
|
||||
[](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml)
|
||||
```
|
||||
|
||||
### Notifications
|
||||
- Configure in: Settings → Notifications
|
||||
- Email alerts for workflow failures
|
||||
- Slack/Discord webhooks supported
|
||||
|
||||
## 🔄 Continuous Improvement
|
||||
|
||||
### Dependabot Updates
|
||||
- Automatic weekly dependency checks (Mondays 9 AM)
|
||||
- Security updates prioritized
|
||||
- Groups patch updates together
|
||||
|
||||
### Code Owners
|
||||
- Auto-assigns reviewers based on file paths
|
||||
- Ensures expertise reviews changes
|
||||
- Speeds up PR review process
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Dependabot Config](.github/dependabot.yml)
|
||||
|
||||
## 🎉 Benefits
|
||||
|
||||
### For Contributors
|
||||
- Clear expectations via PR template
|
||||
- Fast feedback (5-10 min)
|
||||
- Comprehensive local tooling
|
||||
- Detailed error messages
|
||||
|
||||
### For Maintainers
|
||||
- Automated code review
|
||||
- Security scanning
|
||||
- Performance tracking
|
||||
- Quality gates enforcement
|
||||
|
||||
### For Users
|
||||
- Higher code quality
|
||||
- Fewer bugs in production
|
||||
- Better security
|
||||
- Consistent performance
|
||||
|
||||
## 🚦 Success Criteria
|
||||
|
||||
All PRs must pass:
|
||||
- ✅ All 20+ parallel checks
|
||||
- ✅ 75% test coverage minimum
|
||||
- ✅ Zero security vulnerabilities
|
||||
- ✅ No race conditions
|
||||
- ✅ No memory leaks
|
||||
- ✅ All providers tested
|
||||
- ✅ Builds on all platforms
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
1. **Run checks locally** before pushing to save CI time
|
||||
2. **Watch for PR comments** - coverage stats posted automatically
|
||||
3. **Check Security tab** for gosec/CodeQL findings
|
||||
4. **Review benchmark results** in artifacts
|
||||
5. **Use draft PRs** for work-in-progress to skip some checks
|
||||
|
||||
---
|
||||
|
||||
**Ready to go!** 🚀 Push your changes and create a PR to see it in action.
|
||||
@@ -124,6 +124,7 @@ The middleware supports the following configuration options:
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
@@ -138,6 +139,8 @@ The middleware supports the following configuration options:
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
@@ -1241,6 +1244,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1327,8 +1369,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1629,12 +1675,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1862,6 +1935,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
+1375
File diff suppressed because it is too large
Load Diff
@@ -1,931 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
||||
func TestConfigAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty audience is valid for backward compatibility",
|
||||
audience: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTPS URL audience Auth0 format",
|
||||
audience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid identifier audience",
|
||||
audience: "my-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Azure AD Application ID URI format",
|
||||
audience: "api://12345-guid-67890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Auth0 API identifier",
|
||||
audience: "https://my-company.auth0.com/api/v2/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL audience should fail",
|
||||
audience: "http://api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "Audience with wildcard should fail",
|
||||
audience: "https://api.*.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience with single asterisk should fail",
|
||||
audience: "*",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience over 256 characters should fail",
|
||||
audience: strings.Repeat("a", 257),
|
||||
wantErr: true,
|
||||
errContains: "must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with newline should fail",
|
||||
audience: "my-api\ninjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with carriage return should fail",
|
||||
audience: "my-api\rinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with tab should fail",
|
||||
audience: "my-api\tinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Valid audience exactly 256 characters",
|
||||
audience: strings.Repeat("a", 256),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid simple identifier",
|
||||
audience: "my-service-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid URN format",
|
||||
audience: "urn:myservice:api:v1",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
||||
func TestJWTAudienceVerification(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate RSA key for signing JWTs
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
// Create JWK
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
tokenAudience interface{}
|
||||
wantErr bool
|
||||
errContains string
|
||||
skipReplayCheck bool
|
||||
}{
|
||||
{
|
||||
name: "JWT with string aud matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with string aud NOT matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://wrong-api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud NOT containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with clientID as aud when no custom audience configured",
|
||||
configAudience: "",
|
||||
tokenAudience: "test-client-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with empty string aud",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Azure AD Application ID URI format",
|
||||
configAudience: "api://12345-app-id",
|
||||
tokenAudience: "api://12345-app-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Auth0 custom API audience",
|
||||
configAudience: "https://mycompany.com/api",
|
||||
tokenAudience: "https://mycompany.com/api",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Token confusion attack - audience for different service",
|
||||
configAudience: "https://service-a.example.com",
|
||||
tokenAudience: "https://service-b.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Set up the token verifier and JWT verifier
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Determine the expected audience for validation
|
||||
expectedAudience := tt.configAudience
|
||||
if expectedAudience == "" {
|
||||
expectedAudience = tOidc.clientID
|
||||
}
|
||||
|
||||
// Set the audience field on the tOidc instance
|
||||
tOidc.audience = expectedAudience
|
||||
|
||||
// Create JWT with specified audience
|
||||
jti := generateRandomString(16)
|
||||
if tt.skipReplayCheck {
|
||||
// Use a unique JTI for each test to avoid replay detection
|
||||
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
||||
}
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": tt.tokenAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
||||
// when the Audience field is not set
|
||||
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Test with no custom audience configured - should use clientID
|
||||
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // Should match clientID
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
err = ts.tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
||||
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Auth0 scenario: custom audience for API access
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://mycompany.auth0.com"
|
||||
config.ClientID = "auth0-client-id"
|
||||
config.ClientSecret = "auth0-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "https://api.mycompany.com" // Custom API audience
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Auth0 config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "auth0-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches configured audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Auth0 token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.ClientID, // Using clientID instead of API audience
|
||||
"scope": "openid profile email", // Mark as access token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err == nil {
|
||||
t.Error("Auth0 access token with wrong audience should have been rejected")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
||||
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Azure AD scenario: Application ID URI format
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
||||
config.ClientID = "azure-client-id"
|
||||
config.ClientSecret = "azure-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Azure AD config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "azure-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches Application ID URI
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
||||
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
// Service A configuration
|
||||
serviceA := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "service-a-client-id",
|
||||
clientSecret: "service-a-secret",
|
||||
audience: "service-a-client-id", // Service A uses its clientID as audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
serviceA.jwtVerifier = serviceA
|
||||
serviceA.tokenVerifier = serviceA
|
||||
|
||||
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
||||
// Create a token intended for service B
|
||||
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": "https://service-b.example.com", // For service B
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "attacker@example.com",
|
||||
"email": "attacker@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service B token: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify the service B token on service A
|
||||
err = serviceA.VerifyToken(serviceBToken)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
||||
case !strings.Contains(err.Error(), "invalid audience"):
|
||||
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
||||
default:
|
||||
t.Logf("Token confusion attack correctly prevented: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
||||
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
}{
|
||||
{
|
||||
name: "Single asterisk",
|
||||
audience: "*",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in URL",
|
||||
audience: "https://*.example.com",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in path",
|
||||
audience: "https://api.example.com/*",
|
||||
},
|
||||
{
|
||||
name: "Multiple wildcards",
|
||||
audience: "https://*.*.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
||||
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
||||
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
||||
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Newline injection",
|
||||
audience: "api.example.com\nmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Carriage return injection",
|
||||
audience: "api.example.com\rmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Tab injection",
|
||||
audience: "api.example.com\tmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection",
|
||||
audience: "api.example.com\x00malicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
||||
func TestAudienceWithReplayProtection(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a token with custom audience and fixed JTI
|
||||
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
||||
customAudience := "https://api.example.com"
|
||||
|
||||
// Set the audience field to match what we expect
|
||||
tOidc.audience = customAudience
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "user@example.com",
|
||||
"jti": fixedJTI,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Fatalf("First verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the JTI was blacklisted
|
||||
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
||||
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
||||
} else {
|
||||
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
||||
func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Create a test next handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Authenticated with custom audience"))
|
||||
})
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
close(tOidc.initComplete)
|
||||
|
||||
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
||||
// Create a valid token with the custom audience
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.company.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "user-123",
|
||||
"email": "user@company.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a request with authenticated session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
|
||||
// Create session with token
|
||||
resp := httptest.NewRecorder()
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
if err := session.Save(req, resp); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies and add them to a new request
|
||||
cookies := resp.Result().Cookies()
|
||||
req = httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,409 +0,0 @@
|
||||
// Package auth provides authentication-related functionality for the OIDC middleware.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
|
||||
type ScopeFilter interface {
|
||||
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
|
||||
}
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
func (h *Handler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
session SessionData, redirectURL string,
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
||||
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
const maxRedirects = 5
|
||||
redirectCount := session.GetRedirectCount()
|
||||
if redirectCount >= maxRedirects {
|
||||
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
|
||||
session.ResetRedirectCount()
|
||||
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
|
||||
return
|
||||
}
|
||||
|
||||
session.IncrementRedirectCount()
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge if PKCE is enabled
|
||||
var codeVerifier, codeChallenge string
|
||||
if h.enablePKCE {
|
||||
codeVerifier, err = generateCodeVerifier()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code verifier: %v", err)
|
||||
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
codeChallenge, err = deriveCodeChallenge()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code challenge: %v", err)
|
||||
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("PKCE enabled, generated code challenge")
|
||||
}
|
||||
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
if h.enablePKCE {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
|
||||
|
||||
session.MarkDirty()
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
|
||||
csrfToken, nonce)
|
||||
|
||||
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
func (h *Handler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.clientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
if h.enablePKCE && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
scopes := make([]string, len(h.scopes))
|
||||
copy(scopes, h.scopes)
|
||||
|
||||
// Apply discovery-based scope filtering if available
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
|
||||
}
|
||||
|
||||
// Apply provider-specific modifications
|
||||
scopes, params = h.applyProviderSpecificConfig(scopes, params)
|
||||
|
||||
// Final filtering pass to remove anything the provider doesn't support
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
finalScopeString := strings.Join(scopes, " ")
|
||||
params.Set("scope", finalScopeString)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
return h.buildURLWithParams(h.authURL, params)
|
||||
}
|
||||
|
||||
// applyProviderSpecificConfig applies provider-specific scope and parameter modifications
|
||||
func (h *Handler) applyProviderSpecificConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
switch {
|
||||
case h.isGoogleProv():
|
||||
return h.applyGoogleConfig(scopes, params)
|
||||
case h.isAzureProv():
|
||||
return h.applyAzureConfig(scopes, params)
|
||||
default:
|
||||
return h.applyStandardProviderConfig(scopes, params)
|
||||
}
|
||||
}
|
||||
|
||||
// applyGoogleConfig applies Google-specific configuration
|
||||
func (h *Handler) applyGoogleConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
// Google: Remove offline_access if present, add access_type=offline
|
||||
filteredScopes := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
return filteredScopes, params
|
||||
}
|
||||
|
||||
// applyAzureConfig applies Azure AD-specific configuration
|
||||
func (h *Handler) applyAzureConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// applyStandardProviderConfig applies configuration for standard OIDC providers
|
||||
func (h *Handler) applyStandardProviderConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// shouldAddOfflineAccess determines if offline_access scope should be added
|
||||
func (h *Handler) shouldAddOfflineAccess(scopes []string) bool {
|
||||
if h.overrideScopes && len(h.scopes) > 0 {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
// It handles both relative and absolute URLs, validates URL security,
|
||||
// and properly encodes query parameters.
|
||||
func (h *Handler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if baseURL != "" {
|
||||
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
||||
if err := h.validateURL(baseURL); err != nil {
|
||||
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
issuerURLParsed, err := url.Parse(h.issuerURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
||||
|
||||
if err := h.validateURL(resolvedURL.String()); err != nil {
|
||||
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL.RawQuery = params.Encode()
|
||||
return resolvedURL.String()
|
||||
}
|
||||
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if err := h.validateParsedURL(u); err != nil {
|
||||
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
||||
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
||||
func (h *Handler) validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("empty URL")
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
|
||||
return h.validateParsedURL(u)
|
||||
}
|
||||
|
||||
// validateParsedURL validates a parsed URL structure for security.
|
||||
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
||||
func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
allowedSchemes := map[string]bool{
|
||||
"https": true,
|
||||
"http": true,
|
||||
}
|
||||
|
||||
if !allowedSchemes[u.Scheme] {
|
||||
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
|
||||
}
|
||||
|
||||
if u.Scheme == "http" {
|
||||
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
|
||||
if err := h.validateHost(u.Host); err != nil {
|
||||
return fmt.Errorf("invalid host: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(u.Path, "..") {
|
||||
return fmt.Errorf("path traversal detected in URL path")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
// Strip port if present
|
||||
if strings.Contains(host, ":") {
|
||||
var err error
|
||||
host, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid host:port format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
for _, localhost := range localhostVariations {
|
||||
if strings.EqualFold(host, localhost) {
|
||||
return fmt.Errorf("localhost access not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as IP address
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
return fmt.Errorf("link-local IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsMulticast() {
|
||||
return fmt.Errorf("multicast IP not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionData interface for dependency injection
|
||||
type SessionData interface {
|
||||
GetRedirectCount() int
|
||||
ResetRedirectCount()
|
||||
IncrementRedirectCount()
|
||||
SetAuthenticated(bool)
|
||||
SetEmail(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetIDToken(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetCSRF(string)
|
||||
SetIncomingPath(string)
|
||||
MarkDirty()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,562 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_validateURL tests URL validation functionality
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
url: "http://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
url: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "not-a-url",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - javascript",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - data",
|
||||
url: "data:text/html,<script>alert('xss')</script>",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - file",
|
||||
url: "file:///etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - ftp",
|
||||
url: "ftp://example.com/file",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
url: "https://example.com/../../../etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Path traversal in middle",
|
||||
url: "https://example.com/path/../sensitive/file",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Localhost attempt",
|
||||
url: "https://localhost/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 attempt",
|
||||
url: "https://127.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost attempt",
|
||||
url: "https://[::1]/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 attempt",
|
||||
url: "https://0.0.0.0/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 192.168.x.x",
|
||||
url: "https://192.168.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 10.x.x.x",
|
||||
url: "https://10.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 172.16.x.x",
|
||||
url: "https://172.16.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
url: "https://169.254.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
url: "https://224.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Valid public IP",
|
||||
url: "https://8.8.8.8/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid domain with port",
|
||||
url: "https://example.com:8443/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with case variation",
|
||||
url: "https://LOCALHOST/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
url: "https://example.com:notanumber/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid URL format",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateURL(tt.url)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateURL() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost tests host validation specifically
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid hostname",
|
||||
host: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with subdomain",
|
||||
host: "api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with port",
|
||||
host: "example.com:8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty host",
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "localhost",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (case insensitive)",
|
||||
host: "LOCALHOST",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
host: "localhost:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "127.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
host: "127.0.0.1:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
host: "::1",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0",
|
||||
host: "0.0.0.0",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.1.1",
|
||||
host: "192.168.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.0.0.1",
|
||||
host: "10.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 172.16.0.1",
|
||||
host: "172.16.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Public IP 8.8.8.8",
|
||||
host: "8.8.8.8",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
host: "169.254.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
host: "224.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
host: "example.com::",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "Valid international domain",
|
||||
host: "example.org",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid ccTLD",
|
||||
host: "example.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateHost(tt.host)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateHost() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateHost() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams tests URL building with parameters
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
params url.Values
|
||||
expected string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Absolute HTTPS URL",
|
||||
baseURL: "https://provider.com/auth",
|
||||
params: url.Values{
|
||||
"client_id": []string{"test-client"},
|
||||
"response_type": []string{"code"},
|
||||
},
|
||||
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTP URL",
|
||||
baseURL: "http://provider.com/auth",
|
||||
params: url.Values{
|
||||
"state": []string{"test-state"},
|
||||
},
|
||||
expected: "http://provider.com/auth?state=test-state",
|
||||
},
|
||||
{
|
||||
name: "Relative URL resolved against issuer",
|
||||
baseURL: "/oauth2/authorize",
|
||||
params: url.Values{
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
expected: "https://example.com/oauth2/authorize?scope=openid",
|
||||
},
|
||||
{
|
||||
name: "Root relative URL",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{
|
||||
"nonce": []string{"test-nonce"},
|
||||
},
|
||||
expected: "https://example.com/auth?nonce=test-nonce",
|
||||
},
|
||||
{
|
||||
name: "Invalid absolute URL",
|
||||
baseURL: "https://localhost/auth",
|
||||
params: url.Values{},
|
||||
expectEmpty: true, // Should return empty string due to validation failure
|
||||
},
|
||||
{
|
||||
name: "Invalid relative URL when resolved",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{},
|
||||
expected: "", // Should be empty because issuer validation would be tested separately
|
||||
expectEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.buildURLWithParams(tt.baseURL, tt.params)
|
||||
|
||||
if tt.expectEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For relative URLs, we expect them to be resolved against the issuer URL
|
||||
if !strings.HasPrefix(tt.baseURL, "http") {
|
||||
// Verify it starts with the issuer URL
|
||||
if !strings.HasPrefix(result, handler.issuerURL) {
|
||||
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the result to verify parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all expected parameters are present
|
||||
resultParams := parsedURL.Query()
|
||||
for key, expectedValues := range tt.params {
|
||||
actualValues := resultParams[key]
|
||||
if len(actualValues) != len(expectedValues) {
|
||||
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
|
||||
continue
|
||||
}
|
||||
for i, expectedValue := range expectedValues {
|
||||
if actualValues[i] != expectedValue {
|
||||
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
|
||||
"state": []string{"state with spaces and & special chars"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"special": []string{"value+with+plus&ersand=equals"},
|
||||
}
|
||||
|
||||
result := handler.buildURLWithParams("https://provider.com/auth", params)
|
||||
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse result URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify parameters are correctly encoded/decoded
|
||||
resultParams := parsedURL.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"redirect_uri": "https://example.com/callback?test=value&other=data",
|
||||
"state": "state with spaces and & special chars",
|
||||
"scope": "openid profile email",
|
||||
"special": "value+with+plus&ersand=equals",
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedParams {
|
||||
actualValue := resultParams.Get(key)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateParsedURL tests validateParsedURL method
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL with warning",
|
||||
url: "http://example.com/path",
|
||||
wantErr: false, // Should not error but should log warning
|
||||
},
|
||||
{
|
||||
name: "Invalid scheme",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal",
|
||||
url: "https://example.com/path/../../../etc",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Invalid host (private IP)",
|
||||
url: "https://192.168.1.1/path",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(tt.url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse test URL: %v", err)
|
||||
}
|
||||
|
||||
err = handler.validateParsedURL(parsedURL)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateParsedURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateParsedURL() unexpected error = %v", err)
|
||||
}
|
||||
|
||||
// Check for HTTP warning in debug logs
|
||||
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if strings.Contains(msg, "Warning: Using HTTP scheme") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected HTTP scheme warning in debug logs")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file contains tests for Auth0-specific audience validation scenarios.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuth0Scenario1WithCustomAudience tests Auth0 scenario 1:
|
||||
// - Custom audience configured in plugin
|
||||
// - Authorize endpoint called WITH audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, custom_audience]
|
||||
// Expected: Both tokens validate correctly
|
||||
func TestAuth0Scenario1WithCustomAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (OIDC standard)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token always has client_id
|
||||
"nonce": "test-nonce-scenario1", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, custom_audience]
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience, // Custom API audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data", // Access tokens have scope
|
||||
"jti": "access-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed (should validate against client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify access token validates against custom audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token validation failed (should validate against custom audience): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL includes audience parameter (URL-encoded)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if !strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should contain audience parameter when custom audience is configured, got: %s", authURL)
|
||||
}
|
||||
// Verify the audience is properly URL-encoded (contains %3A for :, %2F for /)
|
||||
if !strings.Contains(authURL, "audience=https%3A%2F%2Fmy-api.example.com") {
|
||||
t.Errorf("Auth URL should contain URL-encoded custom audience, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2DefaultAudience tests Auth0 scenario 2:
|
||||
// - No custom audience configured (defaults to client_id)
|
||||
// - Authorize endpoint called WITHOUT audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, default_audience] (no client_id)
|
||||
// Expected: ID token validates, access token falls back to ID token validation
|
||||
func TestAuth0Scenario2DefaultAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// No custom audience - defaults to client_id
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token with aud = client_id
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario2", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, some_default_audience]
|
||||
// This represents Auth0's default audience behavior
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://test-issuer.com/api/v2/", // Default Auth0 Management API
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Access token won't have client_id in aud, so it will fail validation
|
||||
// This is expected for scenario 2 - the session validation relies on ID token
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Logf("Access token validation passed (unexpected but OK if client_id is in aud array)")
|
||||
} else {
|
||||
// Expected failure - access token doesn't have client_id in aud
|
||||
t.Logf("Access token validation failed as expected (aud doesn't contain client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL does NOT include audience parameter (since audience == client_id)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should NOT contain audience parameter when audience equals client_id, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario3OpaqueAccessToken tests Auth0 scenario 3:
|
||||
// - No custom audience configured
|
||||
// - No default audience in Auth0
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: opaque (not JWT)
|
||||
// Expected: ID token validates, opaque access token is accepted
|
||||
func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable opaque tokens for this scenario (Option C requirement)
|
||||
ts.tOidc.allowOpaqueTokens = true
|
||||
|
||||
// No custom audience
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario3", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token (not a JWT - just a random string)
|
||||
opaqueAccessToken := "opaque_access_token_random_string_12345"
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token should fail JWT validation (expected)
|
||||
err = ts.tOidc.VerifyToken(opaqueAccessToken)
|
||||
if err == nil {
|
||||
t.Error("Opaque access token should fail JWT validation")
|
||||
} else {
|
||||
t.Logf("Opaque access token correctly rejected by JWT validator: %v", err)
|
||||
}
|
||||
|
||||
// Test that validateStandardTokens handles opaque tokens correctly
|
||||
// by falling back to ID token validation
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0AudienceArrayValidation tests that audience validation
|
||||
// correctly handles array audiences (common in Auth0)
|
||||
func TestAuth0AudienceArrayValidation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with audience as array containing our custom audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience,
|
||||
"https://another-api.example.com",
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data write:data",
|
||||
"jti": "array-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - custom audience is in the array
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token with audience array should validate when custom audience is present: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0MismatchedAudience tests that tokens with wrong audience fail validation
|
||||
func TestAuth0MismatchedAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with WRONG audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://different-api.example.com", // Wrong audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "wrong-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail validation - audience doesn't match
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Error("Access token with wrong audience should fail validation")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2StrictMode tests strict audience validation mode:
|
||||
// - Scenario 2 (access token with wrong audience) should be REJECTED
|
||||
// - strictAudienceValidation=true prevents fallback to ID token
|
||||
// - This addresses Allan's security concerns about audience bypass
|
||||
func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable strict mode to prevent Scenario 2 bypass (Option C)
|
||||
ts.tOidc.strictAudienceValidation = true
|
||||
|
||||
// Configure custom audience
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (valid)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-strict",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with WRONG audience (doesn't include custom audience)
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://wrong-api.example.com", // Wrong audience - not our custom audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Test session validation with wrong access token and valid ID token
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetIDToken(idToken)
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
if !needsRefresh {
|
||||
t.Errorf("Strict mode: Should signal refresh needed after rejection, got needsRefresh=%v", needsRefresh)
|
||||
}
|
||||
if expired {
|
||||
t.Errorf("Strict mode: Should not mark as expired (should try refresh first), got expired=%v", expired)
|
||||
}
|
||||
|
||||
t.Logf("✓ Strict mode correctly rejected Scenario 2 (access token audience mismatch)")
|
||||
}
|
||||
|
||||
// TestIDTokenAlwaysValidatesAgainstClientID verifies that ID tokens
|
||||
// are ALWAYS validated against client_id, regardless of configured audience
|
||||
func TestIDTokenAlwaysValidatesAgainstClientID(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure a custom audience different from client_id
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (per OIDC spec)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token MUST have client_id
|
||||
"nonce": "test-nonce-123", // ID tokens have nonce for replay protection
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-client-id-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - ID tokens are checked against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token should validate against client_id even when custom audience is configured: %v", err)
|
||||
}
|
||||
|
||||
// Create ID token with WRONG audience (should fail)
|
||||
wrongIDToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": customAudience, // WRONG - should be client_id
|
||||
"nonce": "test-nonce-wrong-456", // ID token has nonce, so it will be detected as ID token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "wrong-id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create wrong ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail - ID tokens must have client_id as audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(wrongIDToken)
|
||||
if err == nil {
|
||||
t.Error("ID token with custom audience (not client_id) should fail validation")
|
||||
}
|
||||
}
|
||||
+19
-13
@@ -8,10 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// AUTHENTICATION FLOW
|
||||
// ============================================================================
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
|
||||
const maxRedirects = 5
|
||||
@@ -223,15 +219,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +246,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// UNIVERSAL CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheSet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Get(fmt.Sprintf("key%d", i%1000))
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheSetGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheLRUEviction(b *testing.B) {
|
||||
config := createTestCacheConfig()
|
||||
config.MaxSize = 100
|
||||
cache := NewUniversalCache(config)
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheConcurrent(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
case 1:
|
||||
cache.Get(fmt.Sprintf("key%d", i))
|
||||
case 2:
|
||||
cache.Delete(fmt.Sprintf("key%d", i))
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE MANAGER BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE COMPATIBILITY BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SHARDED CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkShardedCache(b *testing.B) {
|
||||
b.Run("Set", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Get", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
for i := 0; i < 10000; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get(fmt.Sprintf("key-%d", i%10000))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ParallelSetGet", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
cache.Set(key, i, 5*time.Minute)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
|
||||
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
|
||||
b.Run("ShardedCache64", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
if !cache.Exists(key) {
|
||||
cache.Set(key, true, 5*time.Minute)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("GlobalMutexCache", func(b *testing.B) {
|
||||
var mu sync.RWMutex
|
||||
data := make(map[string]bool)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
|
||||
mu.RLock()
|
||||
_, exists := data[key]
|
||||
mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
mu.Lock()
|
||||
data[key] = true
|
||||
mu.Unlock()
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,369 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewBoundedCache tests creation of bounded cache
|
||||
func TestNewBoundedCache(t *testing.T) {
|
||||
maxSize := 500
|
||||
cache := NewBoundedCache(maxSize)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify we can use basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultUnifiedCacheConfig tests default configuration
|
||||
func TestDefaultUnifiedCacheConfig(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
|
||||
if config.Type != CacheTypeGeneral {
|
||||
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
|
||||
}
|
||||
|
||||
if config.MaxSize != 500 {
|
||||
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
|
||||
}
|
||||
|
||||
if config.MaxMemoryBytes != 64*1024*1024 {
|
||||
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
|
||||
}
|
||||
|
||||
if config.CleanupInterval != 2*time.Minute {
|
||||
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
t.Error("Expected Logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewUnifiedCache tests unified cache creation
|
||||
func TestNewUnifiedCache(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
if cache.UniversalCache == nil {
|
||||
t.Error("Expected UniversalCache to be set")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
|
||||
func TestUnifiedCache_SetMaxSize(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
// Test setting max size
|
||||
newSize := 1000
|
||||
cache.SetMaxSize(newSize)
|
||||
|
||||
// We can't easily verify the size was set without exposing internal fields,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestNewCacheAdapter tests cache adapter creation
|
||||
func TestNewCacheAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache interface{}
|
||||
expectNil bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UniversalCache",
|
||||
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UniversalCache",
|
||||
},
|
||||
{
|
||||
name: "UnifiedCache",
|
||||
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UnifiedCache",
|
||||
},
|
||||
{
|
||||
name: "Invalid cache type",
|
||||
cache: "not-a-cache",
|
||||
expectNil: true,
|
||||
description: "Should return nil for invalid cache type",
|
||||
},
|
||||
{
|
||||
name: "Nil cache",
|
||||
cache: nil,
|
||||
expectNil: true,
|
||||
description: "Should return nil for nil cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adapter := NewCacheAdapter(tt.cache)
|
||||
|
||||
if tt.expectNil {
|
||||
if adapter != nil {
|
||||
t.Errorf("Expected nil adapter, got %v", adapter)
|
||||
}
|
||||
} else {
|
||||
if adapter == nil {
|
||||
t.Error("Expected non-nil adapter")
|
||||
}
|
||||
// Test basic operations
|
||||
adapter.Set("test", "value", time.Hour)
|
||||
value, found := adapter.Get("test")
|
||||
if !found {
|
||||
t.Error("Expected key to be found")
|
||||
}
|
||||
if value != "value" {
|
||||
t.Errorf("Expected 'value', got %v", value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCache tests optimized cache creation
|
||||
func TestNewOptimizedCache(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLRUStrategy tests LRU strategy creation
|
||||
func TestNewLRUStrategy(t *testing.T) {
|
||||
maxSize := 100
|
||||
strategy := NewLRUStrategy(maxSize)
|
||||
|
||||
if strategy == nil {
|
||||
t.Fatal("Expected strategy to be created, got nil")
|
||||
}
|
||||
|
||||
lruStrategy, ok := strategy.(*LRUStrategy)
|
||||
if !ok {
|
||||
t.Fatal("Expected LRUStrategy type")
|
||||
}
|
||||
|
||||
if lruStrategy.maxSize != maxSize {
|
||||
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
|
||||
}
|
||||
|
||||
if lruStrategy.order == nil {
|
||||
t.Error("Expected order list to be initialized")
|
||||
}
|
||||
|
||||
if lruStrategy.elements == nil {
|
||||
t.Error("Expected elements map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_Name tests strategy name
|
||||
func TestLRUStrategy_Name(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
name := strategy.Name()
|
||||
if name != "LRU" {
|
||||
t.Errorf("Expected 'LRU', got %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_ShouldEvict tests eviction logic
|
||||
func TestLRUStrategy_ShouldEvict(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// LRU strategy always returns false for ShouldEvict
|
||||
result := strategy.ShouldEvict("test-item", time.Now())
|
||||
if result != false {
|
||||
t.Error("Expected ShouldEvict to return false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnAccess tests access callback
|
||||
func TestLRUStrategy_OnAccess(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnAccess should not panic
|
||||
strategy.OnAccess("test-key", "test-value")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnRemove tests removal callback
|
||||
func TestLRUStrategy_OnRemove(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnRemove should not panic
|
||||
strategy.OnRemove("test-key")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_EstimateSize tests size estimation
|
||||
func TestLRUStrategy_EstimateSize(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
size := strategy.EstimateSize("test-item")
|
||||
if size != 64 {
|
||||
t.Errorf("Expected size 64, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
|
||||
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
key, found := strategy.GetEvictionCandidate()
|
||||
if found {
|
||||
t.Error("Expected no eviction candidate to be found")
|
||||
}
|
||||
if key != "" {
|
||||
t.Errorf("Expected empty key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
|
||||
func TestNewOptimizedCacheWithConfig(t *testing.T) {
|
||||
config := UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 128 * 1024 * 1024,
|
||||
EnableMetrics: true,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
|
||||
cache := NewOptimizedCacheWithConfig(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewFixedMetadataCache tests fixed metadata cache creation
|
||||
func TestNewFixedMetadataCache(t *testing.T) {
|
||||
cache := NewFixedMetadataCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with proper metadata operations
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/auth",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
}
|
||||
|
||||
err := cache.Set("test-provider", metadata, time.Hour)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error setting metadata: %v", err)
|
||||
}
|
||||
|
||||
// Test that the cache was created (basic verification)
|
||||
// Note: We can't easily test Get without more complex setup
|
||||
}
|
||||
|
||||
// TestNewDoublyLinkedList tests doubly linked list creation
|
||||
func TestNewDoublyLinkedList(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
if list == nil {
|
||||
t.Fatal("Expected list to be created, got nil")
|
||||
}
|
||||
|
||||
// Test it's a proper list structure
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected empty list initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoublyLinkedList_PopFront tests front element removal
|
||||
func TestDoublyLinkedList_PopFront(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
// Test popping from empty list
|
||||
element := list.PopFront()
|
||||
if element != nil {
|
||||
t.Error("Expected nil when popping from empty list")
|
||||
}
|
||||
|
||||
// Add an element and test popping
|
||||
added := list.PushBack("test-value")
|
||||
if added == nil {
|
||||
t.Fatal("Expected element to be added")
|
||||
}
|
||||
|
||||
popped := list.PopFront()
|
||||
if popped == nil {
|
||||
t.Error("Expected element to be popped")
|
||||
}
|
||||
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected list to be empty after popping")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to ensure we have a working cache manager for tests
|
||||
func getTestCacheManager(t *testing.T) *CacheManager {
|
||||
cm := GetGlobalCacheManager(&sync.WaitGroup{})
|
||||
if cm == nil {
|
||||
t.Fatal("Failed to get cache manager")
|
||||
}
|
||||
if cm.manager == nil {
|
||||
t.Fatal("Cache manager has nil internal manager")
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
// TestCacheManager_Close tests cache manager close functionality
|
||||
func TestCacheManager_Close(t *testing.T) {
|
||||
// Get a fresh cache manager
|
||||
wg := &sync.WaitGroup{}
|
||||
cm := GetGlobalCacheManager(wg)
|
||||
|
||||
if cm == nil {
|
||||
t.Fatal("Expected cache manager to be created")
|
||||
}
|
||||
|
||||
// Test closing the cache manager
|
||||
err := cm.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error closing cache manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupGlobalCacheManager tests global cleanup
|
||||
func TestCleanupGlobalCacheManager(t *testing.T) {
|
||||
// Test cleanup when no instance exists (should not error)
|
||||
originalInstance := globalCacheManagerInstance
|
||||
globalCacheManagerInstance = nil
|
||||
err := CleanupGlobalCacheManager()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
|
||||
}
|
||||
|
||||
// Restore original instance
|
||||
globalCacheManagerInstance = originalInstance
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Delete tests delete functionality
|
||||
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Verify it exists
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Fatal("Expected key to be found after setting")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
cache.Delete("test-key")
|
||||
|
||||
// Verify it's gone
|
||||
_, found = cache.Get("test-key")
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Size tests size functionality
|
||||
func TestCacheInterfaceWrapper_Size(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Clear cache first
|
||||
cache.Clear()
|
||||
|
||||
// Check initial size
|
||||
initialSize := cache.Size()
|
||||
if initialSize != 0 {
|
||||
t.Errorf("Expected initial size 0, got %d", initialSize)
|
||||
}
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Check size increased
|
||||
newSize := cache.Size()
|
||||
if newSize != 2 {
|
||||
t.Errorf("Expected size 2, got %d", newSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Clear tests clear functionality
|
||||
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Verify items exist
|
||||
size := cache.Size()
|
||||
if size != 2 {
|
||||
t.Errorf("Expected 2 items before clear, got %d", size)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
cache.Clear()
|
||||
|
||||
// Verify cache is empty
|
||||
size = cache.Size()
|
||||
if size != 0 {
|
||||
t.Errorf("Expected 0 items after clear, got %d", size)
|
||||
}
|
||||
|
||||
// Verify specific items are gone
|
||||
_, found := cache.Get("key1")
|
||||
if found {
|
||||
t.Error("Expected key1 to be cleared")
|
||||
}
|
||||
|
||||
_, found = cache.Get("key2")
|
||||
if found {
|
||||
t.Error("Expected key2 to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
|
||||
func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test close - should not panic
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
wrapper.Close() // Should not panic
|
||||
|
||||
// Test close with nil cache
|
||||
nilWrapper := &CacheInterfaceWrapper{cache: nil}
|
||||
nilWrapper.Close() // Should not panic
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_GetStats tests stats functionality
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := wrapper.GetStats()
|
||||
if stats == nil {
|
||||
t.Error("Expected non-nil stats")
|
||||
}
|
||||
|
||||
// Stats should be accessible (len() never returns negative values)
|
||||
// Just verify it's accessible by checking it's not nil (already done above)
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
|
||||
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item that will expire quickly
|
||||
cache.Set("expire-key", "expire-value", time.Millisecond)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Trigger cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Item should be cleaned up
|
||||
_, found := cache.Get("expire-key")
|
||||
if found {
|
||||
t.Error("Expected expired key to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
|
||||
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test setting max size (should not panic)
|
||||
cache.SetMaxSize(1000)
|
||||
|
||||
// We can't easily verify the size was set without exposing internals,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestGetSharedCaches tests getting shared cache instances
|
||||
func TestGetSharedCaches(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Test getting shared token blacklist
|
||||
blacklist := cm.GetSharedTokenBlacklist()
|
||||
if blacklist == nil {
|
||||
t.Error("Expected non-nil token blacklist")
|
||||
}
|
||||
|
||||
// Test getting shared token cache
|
||||
tokenCache := cm.GetSharedTokenCache()
|
||||
if tokenCache == nil {
|
||||
t.Error("Expected non-nil token cache")
|
||||
}
|
||||
|
||||
// Test getting shared metadata cache
|
||||
metadataCache := cm.GetSharedMetadataCache()
|
||||
if metadataCache == nil {
|
||||
t.Error("Expected non-nil metadata cache")
|
||||
}
|
||||
|
||||
// Test getting shared JWK cache
|
||||
jwkCache := cm.GetSharedJWKCache()
|
||||
if jwkCache == nil {
|
||||
t.Error("Expected non-nil JWK cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentCacheAccess tests thread safety
|
||||
func TestConcurrentCacheAccess(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 10
|
||||
|
||||
// Concurrent operations
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", id, j)
|
||||
value := fmt.Sprintf("value-%d-%d", id, j)
|
||||
|
||||
cache.Set(key, value, time.Hour)
|
||||
|
||||
retrieved, found := cache.Get(key)
|
||||
if found && retrieved != value {
|
||||
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
|
||||
}
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Pre-populate cache
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,319 +0,0 @@
|
||||
// Package circuit_breaker provides circuit breaker implementation for resilience
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the current state of a circuit breaker.
|
||||
// The circuit breaker pattern prevents cascading failures by monitoring
|
||||
// error rates and temporarily blocking requests to failing services.
|
||||
type CircuitBreakerState int
|
||||
|
||||
// Circuit breaker states following the standard pattern:
|
||||
// Closed: Normal operation, requests flow through
|
||||
// Open: Circuit is tripped, requests are blocked
|
||||
// HalfOpen: Testing state, limited requests allowed to test recovery
|
||||
const (
|
||||
// CircuitBreakerClosed allows all requests through (normal operation)
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen blocks all requests (service is failing)
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen allows limited requests to test service recovery
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
// String returns a string representation of the circuit breaker state
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Infof(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism interface for common functionality
|
||||
type BaseRecoveryMechanism interface {
|
||||
RecordRequest()
|
||||
RecordSuccess()
|
||||
RecordFailure()
|
||||
GetBaseMetrics() map[string]interface{}
|
||||
LogInfo(format string, args ...interface{})
|
||||
LogError(format string, args ...interface{})
|
||||
LogDebug(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for external service calls.
|
||||
// It monitors failure rates and automatically opens the circuit when failures
|
||||
// exceed the threshold, preventing further requests until the service recovers.
|
||||
type CircuitBreaker struct {
|
||||
// baseRecovery provides common functionality
|
||||
baseRecovery BaseRecoveryMechanism
|
||||
// maxFailures is the threshold for opening the circuit
|
||||
maxFailures int
|
||||
// timeout is how long to wait before allowing requests in half-open state
|
||||
timeout time.Duration
|
||||
// resetTimeout is how long to wait before transitioning from open to half-open
|
||||
resetTimeout time.Duration
|
||||
// state tracks the current circuit breaker state
|
||||
state CircuitBreakerState
|
||||
// failures counts consecutive failures
|
||||
failures int64
|
||||
// lastFailureTime records when the last failure occurred
|
||||
lastFailureTime time.Time
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// logger for debugging and monitoring
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
|
||||
// These settings control when the circuit opens and how it recovers.
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of failures before opening the circuit
|
||||
MaxFailures int `json:"max_failures"`
|
||||
// Timeout is how long to wait before trying to recover (open -> half-open)
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
// ResetTimeout is how long to wait before fully closing the circuit
|
||||
ResetTimeout time.Duration `json:"reset_timeout"`
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
|
||||
// Configured for typical web service scenarios with moderate tolerance for failures.
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
|
||||
// The circuit breaker starts in the closed state, allowing all requests through.
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
baseRecovery: baseRecovery,
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a function through the circuit breaker with context.
|
||||
// It checks if requests are allowed, executes the function, and updates the circuit state
|
||||
// based on the result. Implements the ErrorRecoveryMechanism interface.
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordRequest()
|
||||
}
|
||||
|
||||
if !cb.allowRequest() {
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordSuccess()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute executes a function through the circuit breaker without context.
|
||||
// This is provided for backward compatibility with existing code.
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest determines whether to allow a request based on the circuit state.
|
||||
// Handles state transitions from open to half-open based on timeout.
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
return true
|
||||
|
||||
case CircuitBreakerOpen:
|
||||
if now.Sub(cb.lastFailureTime) > cb.timeout {
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
if cb.logger != nil {
|
||||
cb.logger.Infof("Circuit breaker transitioning to half-open state")
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit.
|
||||
// Updates failure count and triggers state transitions when thresholds are exceeded.
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failures++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
if cb.failures >= int64(cb.maxFailures) {
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
|
||||
}
|
||||
}
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request and potentially closes the circuit.
|
||||
// Resets failure count and transitions from half-open to closed state on success.
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.failures = 0
|
||||
cb.state = CircuitBreakerClosed
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
|
||||
}
|
||||
|
||||
case CircuitBreakerClosed:
|
||||
cb.failures = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker.
|
||||
// Thread-safe method for monitoring circuit breaker status.
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to its initial closed state.
|
||||
// Clears failure count and state, effectively recovering from any open state.
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.state = CircuitBreakerClosed
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the circuit breaker is currently allowing requests.
|
||||
// This provides a quick way to check if the service is available.
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
return cb.allowRequest()
|
||||
}
|
||||
|
||||
// GetMetrics returns comprehensive metrics about the circuit breaker.
|
||||
// Includes state information, failure counts, configuration, and base metrics.
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
lastFailureTime := cb.lastFailureTime
|
||||
cb.mutex.RUnlock()
|
||||
|
||||
var metrics map[string]interface{}
|
||||
if cb.baseRecovery != nil {
|
||||
metrics = cb.baseRecovery.GetBaseMetrics()
|
||||
} else {
|
||||
metrics = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metrics["state"] = state.String()
|
||||
metrics["current_failures"] = failures
|
||||
metrics["max_failures"] = cb.maxFailures
|
||||
metrics["timeout"] = cb.timeout.String()
|
||||
metrics["reset_timeout"] = cb.resetTimeout.String()
|
||||
|
||||
if !lastFailureTime.IsZero() {
|
||||
metrics["last_failure_time"] = lastFailureTime
|
||||
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// GetFailureCount returns the current failure count
|
||||
func (cb *CircuitBreaker) GetFailureCount() int64 {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.failures
|
||||
}
|
||||
|
||||
// GetLastFailureTime returns the time of the last failure
|
||||
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.lastFailureTime
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is in open state
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// IsClosed returns true if the circuit breaker is in closed state
|
||||
func (cb *CircuitBreaker) IsClosed() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerClosed
|
||||
}
|
||||
|
||||
// IsHalfOpen returns true if the circuit breaker is in half-open state
|
||||
func (cb *CircuitBreaker) IsHalfOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerHalfOpen
|
||||
}
|
||||
@@ -1,981 +0,0 @@
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockLogger struct {
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future error log verification tests
|
||||
func (m *mockLogger) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future test isolation
|
||||
func (m *mockLogger) reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = nil
|
||||
m.errorLogs = nil
|
||||
m.debugLogs = nil
|
||||
}
|
||||
|
||||
type mockBaseRecoveryMechanism struct {
|
||||
requestCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
baseMetrics map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
|
||||
return &mockBaseRecoveryMechanism{
|
||||
baseMetrics: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&m.requestCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&m.successCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&m.failureCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range m.baseMetrics {
|
||||
result[k] = v
|
||||
}
|
||||
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
|
||||
result["total_successes"] = atomic.LoadInt64(&m.successCount)
|
||||
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
|
||||
return atomic.LoadInt64(&m.requestCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
|
||||
return atomic.LoadInt64(&m.successCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
|
||||
return atomic.LoadInt64(&m.failureCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func TestCircuitBreakerState_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
state CircuitBreakerState
|
||||
expected string
|
||||
}{
|
||||
{CircuitBreakerClosed, "closed"},
|
||||
{CircuitBreakerOpen, "open"},
|
||||
{CircuitBreakerHalfOpen, "half-open"},
|
||||
{CircuitBreakerState(999), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := tt.state.String()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCircuitBreaker(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
if cb == nil {
|
||||
t.Fatal("NewCircuitBreaker returned nil")
|
||||
}
|
||||
|
||||
if cb.maxFailures != 3 {
|
||||
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
|
||||
}
|
||||
|
||||
if cb.timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
|
||||
}
|
||||
|
||||
if cb.resetTimeout != 15*time.Second {
|
||||
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
|
||||
}
|
||||
|
||||
if cb.state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
|
||||
}
|
||||
|
||||
if cb.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
|
||||
if cb.baseRecovery != baseRecovery {
|
||||
t.Error("Expected baseRecovery to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getSuccessCount() != 1 {
|
||||
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getFailureCount() != 1 {
|
||||
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Execute(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.Execute(testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First failure
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on first failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Second failure - should open circuit
|
||||
err = cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on second failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Third attempt - should be blocked
|
||||
callCount := 0
|
||||
blockedFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
err = cb.ExecuteWithContext(ctx, blockedFunc)
|
||||
if err == nil {
|
||||
t.Error("Expected error when circuit is open")
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond, // Very short for testing
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next request should transition to half-open
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err = cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error in half-open state, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// First call should transition to half-open, but we'll force it by checking allowRequest
|
||||
if !cb.allowRequest() {
|
||||
t.Error("Expected allowRequest to return true after timeout")
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerHalfOpen {
|
||||
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Failure in half-open should return to open
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Reset circuit
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if cb.GetFailureCount() != 0 {
|
||||
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
|
||||
}
|
||||
|
||||
// Should allow requests again
|
||||
callCount := 0
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after reset, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially available
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should not be available when open
|
||||
if cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be unavailable when open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should be available again after timeout (half-open)
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available after timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateCheckers(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially closed
|
||||
if !cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker to be closed initially")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open initially")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should be open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when open")
|
||||
}
|
||||
if !cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker to be open")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open when open")
|
||||
}
|
||||
|
||||
// Wait for timeout and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
cb.allowRequest() // This will transition to half-open
|
||||
|
||||
// Should be half-open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when half-open")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open when half-open")
|
||||
}
|
||||
if !cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker to be half-open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Record some activity
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Check circuit breaker specific metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["current_failures"] != int64(1) {
|
||||
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
if metrics["timeout"] != "30s" {
|
||||
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
|
||||
}
|
||||
|
||||
if metrics["reset_timeout"] != "15s" {
|
||||
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
|
||||
}
|
||||
|
||||
// Check base metrics are included
|
||||
if metrics["total_requests"] != int64(1) {
|
||||
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["custom_metric"] != "custom_value" {
|
||||
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
|
||||
}
|
||||
|
||||
// Check failure time metrics
|
||||
if _, exists := metrics["last_failure_time"]; !exists {
|
||||
t.Error("Expected last_failure_time to exist")
|
||||
}
|
||||
|
||||
if _, exists := metrics["time_since_last_failure"]; !exists {
|
||||
t.Error("Expected time_since_last_failure to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Should still have circuit breaker metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
// Should not have base metrics
|
||||
if _, exists := metrics["total_requests"]; exists {
|
||||
t.Error("Expected total_requests not to exist without base recovery")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially should be zero
|
||||
if !cb.GetLastFailureTime().IsZero() {
|
||||
t.Error("Expected last failure time to be zero initially")
|
||||
}
|
||||
|
||||
// Record a failure
|
||||
before := time.Now()
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
after := time.Now()
|
||||
|
||||
lastFailure := cb.GetLastFailureTime()
|
||||
if lastFailure.IsZero() {
|
||||
t.Error("Expected last failure time to be set after failure")
|
||||
}
|
||||
|
||||
if lastFailure.Before(before) || lastFailure.After(after) {
|
||||
t.Errorf("Expected last failure time to be between %v and %v, got %v",
|
||||
before, after, lastFailure)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
// Should work fine without base recovery
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 10, // Higher threshold for concurrent test
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
errorCount := int64(0)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
// Simulate some failures
|
||||
if j%10 == 9 { // Every 10th operation fails
|
||||
return fmt.Errorf("simulated error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
|
||||
// Intermittently check state and metrics
|
||||
if j%5 == 0 {
|
||||
cb.GetState()
|
||||
cb.GetMetrics()
|
||||
cb.IsAvailable()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify we got both successes and errors
|
||||
finalSuccessCount := atomic.LoadInt64(&successCount)
|
||||
finalErrorCount := atomic.LoadInt64(&errorCount)
|
||||
|
||||
if finalSuccessCount == 0 {
|
||||
t.Error("Expected some successful operations")
|
||||
}
|
||||
|
||||
if finalErrorCount == 0 {
|
||||
t.Error("Expected some failed operations")
|
||||
}
|
||||
|
||||
totalOperations := finalSuccessCount + finalErrorCount
|
||||
expectedMax := int64(numGoroutines * numOperations)
|
||||
|
||||
if totalOperations > expectedMax {
|
||||
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
|
||||
}
|
||||
|
||||
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
|
||||
finalSuccessCount, finalErrorCount, cb.GetState())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Check that error was logged when circuit opened
|
||||
errorLogs := baseRecovery.getErrorLogs()
|
||||
if len(errorLogs) == 0 {
|
||||
t.Error("Expected error log when circuit breaker opened")
|
||||
} else {
|
||||
if !contains(errorLogs, "Circuit breaker opened after") {
|
||||
t.Errorf("Expected circuit opening log, got %v", errorLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Successful request should close circuit and log
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Check that success was logged when circuit closed
|
||||
infoLogs := baseRecovery.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log when circuit breaker closed")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker closed after successful request") {
|
||||
t.Errorf("Expected circuit closing log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset should also be logged
|
||||
cb.Reset()
|
||||
infoLogs = baseRecovery.getInfoLogs()
|
||||
if !contains(infoLogs, "Circuit breaker has been reset") {
|
||||
t.Errorf("Expected reset log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Wait for timeout and check half-open transition logging
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next allowRequest call should log transition to half-open
|
||||
cb.allowRequest()
|
||||
|
||||
infoLogs := logger.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log for half-open transition")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
|
||||
t.Errorf("Expected half-open transition log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a slice contains a string with substring
|
||||
func contains(slice []string, substr string) bool {
|
||||
for _, s := range slice {
|
||||
if len(s) >= len(substr) && s[:len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testFunc := func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1000, // High threshold to avoid opening during benchmark
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.GetState()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Add some activity
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
} else {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.GetMetrics()
|
||||
}
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
// Package config provides backward compatibility for legacy configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// LegacyAdapter provides backward compatibility for old Config struct
|
||||
type LegacyAdapter struct {
|
||||
unified *UnifiedConfig
|
||||
adapter *compat.ConfigAdapter
|
||||
}
|
||||
|
||||
// NewLegacyAdapter creates a new legacy adapter from unified config
|
||||
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
|
||||
adapter := compat.NewConfigAdapter(unified)
|
||||
|
||||
// Register getters for commonly used fields
|
||||
adapter.RegisterGetter("ProviderURL", func() interface{} {
|
||||
return unified.Provider.IssuerURL
|
||||
})
|
||||
adapter.RegisterGetter("ClientID", func() interface{} {
|
||||
return unified.Provider.ClientID
|
||||
})
|
||||
adapter.RegisterGetter("ClientSecret", func() interface{} {
|
||||
return unified.Provider.ClientSecret
|
||||
})
|
||||
adapter.RegisterGetter("CallbackURL", func() interface{} {
|
||||
return unified.Provider.RedirectURL
|
||||
})
|
||||
adapter.RegisterGetter("LogoutURL", func() interface{} {
|
||||
return unified.Provider.LogoutURL
|
||||
})
|
||||
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
|
||||
return unified.Provider.PostLogoutRedirectURI
|
||||
})
|
||||
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
|
||||
return unified.Session.EncryptionKey
|
||||
})
|
||||
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
|
||||
return unified.Security.ForceHTTPS
|
||||
})
|
||||
adapter.RegisterGetter("LogLevel", func() interface{} {
|
||||
return unified.Logging.Level
|
||||
})
|
||||
adapter.RegisterGetter("Scopes", func() interface{} {
|
||||
return unified.Provider.Scopes
|
||||
})
|
||||
adapter.RegisterGetter("OverrideScopes", func() interface{} {
|
||||
return unified.Provider.OverrideScopes
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUsers", func() interface{} {
|
||||
return unified.Security.AllowedUsers
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
|
||||
return unified.Security.AllowedUserDomains
|
||||
})
|
||||
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
|
||||
return unified.Security.AllowedRolesAndGroups
|
||||
})
|
||||
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
|
||||
return unified.Security.ExcludedURLs
|
||||
})
|
||||
adapter.RegisterGetter("EnablePKCE", func() interface{} {
|
||||
return unified.Security.EnablePKCE
|
||||
})
|
||||
adapter.RegisterGetter("RateLimit", func() interface{} {
|
||||
return unified.RateLimit.RequestsPerSecond
|
||||
})
|
||||
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
|
||||
return int(unified.Token.RefreshGracePeriod.Seconds())
|
||||
})
|
||||
adapter.RegisterGetter("CookieDomain", func() interface{} {
|
||||
return unified.Session.Domain
|
||||
})
|
||||
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
|
||||
return unified.Security.Headers
|
||||
})
|
||||
|
||||
return &LegacyAdapter{
|
||||
unified: unified,
|
||||
adapter: adapter,
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig converts unified config to old Config struct format
|
||||
func (la *LegacyAdapter) ToOldConfig() *Config {
|
||||
// Use feature flags to determine behavior
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Return existing Config if unified config not enabled
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
ProviderURL: la.unified.Provider.IssuerURL,
|
||||
ClientID: la.unified.Provider.ClientID,
|
||||
ClientSecret: la.unified.Provider.ClientSecret,
|
||||
CallbackURL: la.unified.Provider.RedirectURL,
|
||||
LogoutURL: la.unified.Provider.LogoutURL,
|
||||
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
|
||||
SessionEncryptionKey: la.unified.Session.EncryptionKey,
|
||||
ForceHTTPS: la.unified.Security.ForceHTTPS,
|
||||
LogLevel: la.unified.Logging.Level,
|
||||
Scopes: la.unified.Provider.Scopes,
|
||||
OverrideScopes: la.unified.Provider.OverrideScopes,
|
||||
AllowedUsers: la.unified.Security.AllowedUsers,
|
||||
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
|
||||
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
|
||||
ExcludedURLs: la.unified.Security.ExcludedURLs,
|
||||
EnablePKCE: la.unified.Security.EnablePKCE,
|
||||
RateLimit: la.unified.RateLimit.RequestsPerSecond,
|
||||
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
|
||||
Headers: la.convertHeaders(),
|
||||
CookieDomain: la.unified.Session.Domain,
|
||||
SecurityHeaders: la.unified.Security.Headers,
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// convertHeaders converts unified header config to old format
|
||||
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
|
||||
headers := make([]HeaderConfig, 0)
|
||||
|
||||
for name, value := range la.unified.Middleware.CustomHeaders {
|
||||
headers = append(headers, HeaderConfig{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// FromOldConfig creates unified config from old Config struct
|
||||
func FromOldConfig(old *Config) *UnifiedConfig {
|
||||
unified := NewUnifiedConfig()
|
||||
|
||||
// Map provider settings
|
||||
unified.Provider.IssuerURL = old.ProviderURL
|
||||
unified.Provider.ClientID = old.ClientID
|
||||
unified.Provider.ClientSecret = old.ClientSecret
|
||||
unified.Provider.RedirectURL = old.CallbackURL
|
||||
unified.Provider.LogoutURL = old.LogoutURL
|
||||
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
|
||||
unified.Provider.Scopes = old.Scopes
|
||||
unified.Provider.OverrideScopes = old.OverrideScopes
|
||||
|
||||
// Map session settings
|
||||
unified.Session.EncryptionKey = old.SessionEncryptionKey
|
||||
unified.Session.Domain = old.CookieDomain
|
||||
|
||||
// Map security settings
|
||||
unified.Security.ForceHTTPS = old.ForceHTTPS
|
||||
unified.Security.EnablePKCE = old.EnablePKCE
|
||||
unified.Security.AllowedUsers = old.AllowedUsers
|
||||
unified.Security.AllowedUserDomains = old.AllowedUserDomains
|
||||
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
|
||||
unified.Security.ExcludedURLs = old.ExcludedURLs
|
||||
unified.Security.Headers = old.SecurityHeaders
|
||||
|
||||
// Map rate limiting
|
||||
unified.RateLimit.RequestsPerSecond = old.RateLimit
|
||||
unified.RateLimit.Enabled = old.RateLimit > 0
|
||||
|
||||
// Map token settings
|
||||
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
|
||||
|
||||
// Map logging
|
||||
unified.Logging.Level = old.LogLevel
|
||||
|
||||
// Map custom headers
|
||||
if len(old.Headers) > 0 {
|
||||
unified.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, header := range old.Headers {
|
||||
unified.Middleware.CustomHeaders[header.Name] = header.Value
|
||||
}
|
||||
}
|
||||
|
||||
// Store original config in legacy field for reference
|
||||
unified.Legacy["original"] = old
|
||||
|
||||
return unified
|
||||
}
|
||||
|
||||
// timeSecondsToDuration converts seconds to time.Duration
|
||||
func timeSecondsToDuration(seconds int) time.Duration {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// GetConfigInterface returns appropriate config based on feature flag
|
||||
func GetConfigInterface() interface{} {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
// ValidateConfig validates config based on feature flag
|
||||
func ValidateConfig(cfg interface{}) error {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if unified, ok := cfg.(*UnifiedConfig); ok {
|
||||
return unified.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to old validation if available
|
||||
if old, ok := cfg.(*Config); ok {
|
||||
return old.Validate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add Validate method to old Config for compatibility
|
||||
func (c *Config) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Basic validation for old config
|
||||
if c.ProviderURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ProviderURL",
|
||||
Message: "provider URL is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientSecret == "" && !c.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE)",
|
||||
})
|
||||
}
|
||||
|
||||
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "SessionEncryptionKey",
|
||||
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
|
||||
Value: len(c.SessionEncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,363 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// NewLegacyAdapter Tests
|
||||
func TestNewLegacyAdapter(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "test-client"
|
||||
unified.Provider.ClientSecret = "test-secret"
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("Expected NewLegacyAdapter to return non-nil")
|
||||
}
|
||||
|
||||
if adapter.unified != unified {
|
||||
t.Error("Expected adapter to reference the unified config")
|
||||
}
|
||||
|
||||
if adapter.adapter == nil {
|
||||
t.Error("Expected internal adapter to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig Tests
|
||||
func TestLegacyAdapter_ToOldConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://issuer.example.com"
|
||||
unified.Provider.ClientID = "client-123"
|
||||
unified.Provider.ClientSecret = "secret-456"
|
||||
unified.Provider.RedirectURL = "https://app.example.com/callback"
|
||||
unified.Provider.LogoutURL = "/logout"
|
||||
unified.Provider.PostLogoutRedirectURI = "https://app.example.com"
|
||||
unified.Provider.Scopes = []string{"openid", "profile"}
|
||||
unified.Provider.OverrideScopes = true
|
||||
unified.Session.EncryptionKey = "test-encryption-key-32-chars!!"
|
||||
unified.Session.Domain = "example.com"
|
||||
unified.Security.ForceHTTPS = true
|
||||
unified.Security.EnablePKCE = true
|
||||
unified.Security.AllowedUsers = []string{"user@example.com"}
|
||||
unified.Security.AllowedUserDomains = []string{"example.com"}
|
||||
unified.Security.AllowedRolesAndGroups = []string{"admin"}
|
||||
unified.Security.ExcludedURLs = []string{"/health"}
|
||||
unified.RateLimit.RequestsPerSecond = 100
|
||||
unified.Logging.Level = "debug"
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Header-1": "value1",
|
||||
"X-Header-2": "value2",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
oldConfig := adapter.ToOldConfig()
|
||||
|
||||
if oldConfig == nil {
|
||||
t.Fatal("Expected ToOldConfig to return non-nil")
|
||||
}
|
||||
|
||||
// ToOldConfig behavior depends on feature flag
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// When feature is disabled, returns default config
|
||||
if oldConfig.ProviderURL == "" {
|
||||
t.Log("Feature flag disabled - ToOldConfig returns default config")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// When feature is enabled, verify all fields were correctly mapped
|
||||
if oldConfig.ProviderURL != unified.Provider.IssuerURL {
|
||||
t.Errorf("Expected ProviderURL '%s', got '%s'", unified.Provider.IssuerURL, oldConfig.ProviderURL)
|
||||
}
|
||||
|
||||
if oldConfig.ClientID != unified.Provider.ClientID {
|
||||
t.Errorf("Expected ClientID '%s', got '%s'", unified.Provider.ClientID, oldConfig.ClientID)
|
||||
}
|
||||
|
||||
if oldConfig.ClientSecret != unified.Provider.ClientSecret {
|
||||
t.Errorf("Expected ClientSecret '%s', got '%s'", unified.Provider.ClientSecret, oldConfig.ClientSecret)
|
||||
}
|
||||
|
||||
if oldConfig.CallbackURL != unified.Provider.RedirectURL {
|
||||
t.Error("Expected CallbackURL to match RedirectURL")
|
||||
}
|
||||
|
||||
if oldConfig.LogoutURL != unified.Provider.LogoutURL {
|
||||
t.Error("Expected LogoutURL to match")
|
||||
}
|
||||
|
||||
if oldConfig.ForceHTTPS != unified.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to match")
|
||||
}
|
||||
|
||||
if oldConfig.EnablePKCE != unified.Security.EnablePKCE {
|
||||
t.Error("Expected EnablePKCE to match")
|
||||
}
|
||||
|
||||
if oldConfig.RateLimit != unified.RateLimit.RequestsPerSecond {
|
||||
t.Errorf("Expected RateLimit %d, got %d", unified.RateLimit.RequestsPerSecond, oldConfig.RateLimit)
|
||||
}
|
||||
|
||||
if len(oldConfig.Headers) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(oldConfig.Headers))
|
||||
}
|
||||
}
|
||||
|
||||
// convertHeaders Tests
|
||||
func TestLegacyAdapter_convertHeaders(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Custom-Header-1": "value1",
|
||||
"X-Custom-Header-2": "value2",
|
||||
"X-Custom-Header-3": "value3",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 3 {
|
||||
t.Errorf("Expected 3 headers, got %d", len(headers))
|
||||
}
|
||||
|
||||
// Check that headers were converted
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h.Name] = h.Value
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-1"] != "value1" {
|
||||
t.Error("Expected X-Custom-Header-1 to have value 'value1'")
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-2"] != "value2" {
|
||||
t.Error("Expected X-Custom-Header-2 to have value 'value2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAdapter_convertHeaders_Empty(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
// No custom headers
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 0 {
|
||||
t.Errorf("Expected 0 headers, got %d", len(headers))
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigInterface Tests
|
||||
func TestGetConfigInterface(t *testing.T) {
|
||||
cfg := GetConfigInterface()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("Expected GetConfigInterface to return non-nil")
|
||||
}
|
||||
|
||||
// Should return either UnifiedConfig or Config depending on feature flag
|
||||
_, isUnified := cfg.(*UnifiedConfig)
|
||||
_, isOld := cfg.(*Config)
|
||||
|
||||
if !isUnified && !isOld {
|
||||
t.Error("Expected either *UnifiedConfig or *Config")
|
||||
}
|
||||
|
||||
// Verify consistency with feature flag
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if !isUnified {
|
||||
t.Error("Expected *UnifiedConfig when unified config is enabled")
|
||||
}
|
||||
} else {
|
||||
if !isOld {
|
||||
t.Error("Expected *Config when unified config is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig Tests
|
||||
func TestValidateConfig_UnifiedConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "client-id"
|
||||
unified.Provider.ClientSecret = "client-secret"
|
||||
unified.Session.EncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(unified)
|
||||
// Should succeed regardless of feature flag since we're passing the right type
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid unified config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_OldConfig(t *testing.T) {
|
||||
old := CreateConfig()
|
||||
old.ProviderURL = "https://provider.example.com"
|
||||
old.ClientID = "client-id"
|
||||
old.ClientSecret = "client-secret"
|
||||
old.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(old)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid old config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidType(t *testing.T) {
|
||||
// Pass something that's not a config
|
||||
err := ValidateConfig("not a config")
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil for unknown type, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Config.Validate Tests
|
||||
func TestConfig_Validate_Valid(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid config to pass, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingProviderURL(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ProviderURL")
|
||||
}
|
||||
|
||||
// Check if it's a ValidationErrors type
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ProviderURL" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ProviderURL validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientID(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientID")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientID" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientID validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_NoPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = false
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientSecret without PKCE")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientSecret" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientSecret validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_WithPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = true // PKCE enabled, so ClientSecret not required
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with PKCE enabled and no ClientSecret, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ShortEncryptionKey(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "short" // Too short
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for short encryption key")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "SessionEncryptionKey" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected SessionEncryptionKey validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MultipleErrors(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
// Missing ProviderURL, ClientID, and ClientSecret
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected validation errors")
|
||||
}
|
||||
|
||||
verrs, ok := err.(ValidationErrors)
|
||||
if !ok {
|
||||
t.Fatal("Expected ValidationErrors type")
|
||||
}
|
||||
|
||||
if len(verrs) < 2 {
|
||||
t.Errorf("Expected at least 2 validation errors, got %d", len(verrs))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,276 +0,0 @@
|
||||
// Package config provides default values and initialization for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewUnifiedConfig creates a new unified configuration with sensible defaults
|
||||
func NewUnifiedConfig() *UnifiedConfig {
|
||||
return &UnifiedConfig{
|
||||
Provider: DefaultProviderConfig(),
|
||||
Session: DefaultSessionConfig(),
|
||||
Token: DefaultTokenConfig(),
|
||||
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
|
||||
Security: DefaultSecurityConfig(),
|
||||
Middleware: DefaultMiddlewareConfig(),
|
||||
Cache: DefaultCacheConfig(),
|
||||
RateLimit: DefaultRateLimitConfig(),
|
||||
Logging: DefaultLoggingConfig(),
|
||||
Metrics: DefaultMetricsConfig(),
|
||||
Health: DefaultHealthConfig(),
|
||||
Transport: DefaultTransportConfig(),
|
||||
Pool: DefaultPoolConfig(),
|
||||
Circuit: DefaultCircuitConfig(),
|
||||
Legacy: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultProviderConfig returns default provider configuration
|
||||
func DefaultProviderConfig() ProviderConfig {
|
||||
return ProviderConfig{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
OverrideScopes: false,
|
||||
CustomClaims: make(map[string]string),
|
||||
JWKCachePeriod: 24 * time.Hour,
|
||||
MetadataCacheTTL: 24 * time.Hour,
|
||||
Discovery: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSessionConfig returns default session configuration
|
||||
func DefaultSessionConfig() SessionConfig {
|
||||
return SessionConfig{
|
||||
Name: "oidc_session",
|
||||
MaxAge: 86400, // 24 hours
|
||||
ChunkSize: 4000, // Safe size for cookies
|
||||
MaxChunks: 5,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
StorageType: "cookie",
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTokenConfig returns default token configuration
|
||||
func DefaultTokenConfig() TokenConfig {
|
||||
return TokenConfig{
|
||||
AccessTokenTTL: 1 * time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
RefreshGracePeriod: 60 * time.Second,
|
||||
ValidationMode: "jwt",
|
||||
CacheEnabled: true,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
CacheNegativeTTL: 30 * time.Second,
|
||||
ValidateSignature: true,
|
||||
ValidateExpiry: true,
|
||||
ValidateAudience: true,
|
||||
ValidateIssuer: true,
|
||||
RequiredClaims: []string{"sub", "iat", "exp"},
|
||||
ClockSkew: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSecurityConfig returns default security configuration
|
||||
func DefaultSecurityConfig() SecurityConfig {
|
||||
return SecurityConfig{
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
AllowedUsers: []string{},
|
||||
AllowedUserDomains: []string{},
|
||||
AllowedRolesAndGroups: []string{},
|
||||
ExcludedURLs: []string{
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
"/health",
|
||||
"/.well-known/",
|
||||
"/metrics",
|
||||
"/ping",
|
||||
"/static/",
|
||||
"/assets/",
|
||||
"/js/",
|
||||
"/css/",
|
||||
"/images/",
|
||||
"/fonts/",
|
||||
},
|
||||
Headers: createDefaultSecurityConfig(),
|
||||
CSRFProtection: true,
|
||||
CSRFTokenName: "csrf_token",
|
||||
CSRFTokenTTL: 1 * time.Hour,
|
||||
MaxLoginAttempts: 5,
|
||||
LockoutDuration: 15 * time.Minute,
|
||||
RequireMFA: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMiddlewareConfig returns default middleware configuration
|
||||
func DefaultMiddlewareConfig() MiddlewareConfig {
|
||||
return MiddlewareConfig{
|
||||
Priority: 1000,
|
||||
SkipPaths: []string{},
|
||||
RequirePaths: []string{},
|
||||
PassthroughMode: false,
|
||||
MaxRequestSize: 10 * 1024 * 1024, // 10MB
|
||||
RequestTimeout: 30 * time.Second,
|
||||
IdleTimeout: 90 * time.Second,
|
||||
CustomHeaders: make(map[string]string),
|
||||
RemoveHeaders: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCacheConfig returns default cache configuration
|
||||
func DefaultCacheConfig() CacheConfig {
|
||||
return CacheConfig{
|
||||
Enabled: true,
|
||||
Type: "memory",
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxEntries: 10000,
|
||||
MaxEntrySize: 1024 * 1024, // 1MB
|
||||
EvictionPolicy: "lru",
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
Namespace: "traefikoidc",
|
||||
Compression: false,
|
||||
Serialization: "json",
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRateLimitConfig returns default rate limiting configuration
|
||||
func DefaultRateLimitConfig() RateLimitConfig {
|
||||
return RateLimitConfig{
|
||||
Enabled: false,
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
StorageType: "memory",
|
||||
WindowDuration: 1 * time.Minute,
|
||||
KeyType: "ip",
|
||||
CustomKeyFunc: "",
|
||||
WhitelistIPs: []string{},
|
||||
WhitelistUsers: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLoggingConfig returns default logging configuration
|
||||
func DefaultLoggingConfig() LoggingConfig {
|
||||
return LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
FilePath: "",
|
||||
FilterSensitive: true,
|
||||
MaskFields: []string{
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
},
|
||||
BufferSize: 8192,
|
||||
FlushInterval: 5 * time.Second,
|
||||
AuditEnabled: false,
|
||||
AuditEvents: []string{
|
||||
"login",
|
||||
"logout",
|
||||
"token_refresh",
|
||||
"auth_failure",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMetricsConfig returns default metrics configuration
|
||||
func DefaultMetricsConfig() MetricsConfig {
|
||||
return MetricsConfig{
|
||||
Enabled: false,
|
||||
Provider: "prometheus",
|
||||
Endpoint: "/metrics",
|
||||
Namespace: "traefikoidc",
|
||||
Subsystem: "middleware",
|
||||
CollectInterval: 10 * time.Second,
|
||||
Histograms: true,
|
||||
Labels: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHealthConfig returns default health check configuration
|
||||
func DefaultHealthConfig() HealthConfig {
|
||||
return HealthConfig{
|
||||
Enabled: true,
|
||||
Path: "/health",
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
CheckProvider: true,
|
||||
CheckRedis: true,
|
||||
CheckCache: true,
|
||||
MaxLatency: 1 * time.Second,
|
||||
MinMemory: 100 * 1024 * 1024, // 100MB
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTransportConfig returns default HTTP transport configuration
|
||||
func DefaultTransportConfig() TransportConfig {
|
||||
return TransportConfig{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 0, // No limit
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
TLSMinVersion: "TLS1.2",
|
||||
TLSCipherSuites: []string{},
|
||||
ProxyURL: "",
|
||||
NoProxy: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultPoolConfig returns default connection pool configuration
|
||||
func DefaultPoolConfig() PoolConfig {
|
||||
return PoolConfig{
|
||||
Enabled: true,
|
||||
Size: 10,
|
||||
MinSize: 2,
|
||||
MaxSize: 50,
|
||||
MaxAge: 30 * time.Minute,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
WaitTimeout: 5 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCircuitConfig returns default circuit breaker configuration
|
||||
func DefaultCircuitConfig() CircuitConfig {
|
||||
return CircuitConfig{
|
||||
Enabled: true,
|
||||
MaxRequests: 100,
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 60 * time.Second,
|
||||
ConsecutiveFailures: 5,
|
||||
FailureRatio: 0.5,
|
||||
OnOpen: "reject",
|
||||
OnHalfOpen: "passthrough",
|
||||
MetricsEnabled: true,
|
||||
LogStateChanges: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeWithDefaults merges a partial configuration with defaults
|
||||
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
|
||||
if partial == nil {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
|
||||
// Ensure Legacy field is initialized
|
||||
if partial.Legacy == nil {
|
||||
partial.Legacy = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// TODO: Implement deep merge logic with defaults
|
||||
// For now, just return the partial config
|
||||
return partial
|
||||
}
|
||||
@@ -1,396 +0,0 @@
|
||||
// Package config provides configuration loading and merging logic
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigLoader handles loading configuration from various sources
|
||||
type ConfigLoader struct {
|
||||
migrator *ConfigMigrator
|
||||
envPrefix string
|
||||
configPaths []string
|
||||
}
|
||||
|
||||
// NewConfigLoader creates a new configuration loader
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
migrator: NewConfigMigrator(),
|
||||
envPrefix: "TRAEFIKOIDC_",
|
||||
configPaths: getDefaultConfigPaths(),
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultConfigPaths returns default configuration file paths to check
|
||||
func getDefaultConfigPaths() []string {
|
||||
return []string{
|
||||
"traefik-oidc.yaml",
|
||||
"traefik-oidc.yml",
|
||||
"traefik-oidc.json",
|
||||
"config.yaml",
|
||||
"config.yml",
|
||||
"config.json",
|
||||
"/etc/traefik-oidc/config.yaml",
|
||||
"/etc/traefik-oidc/config.json",
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from all available sources
|
||||
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
|
||||
// Start with defaults
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Try to load from file
|
||||
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
|
||||
config = l.mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
l.LoadFromEnv(config)
|
||||
|
||||
// Validate the final configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadFromFile loads configuration from a file
|
||||
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
|
||||
// Use provided paths or default paths
|
||||
searchPaths := paths
|
||||
if len(searchPaths) == 0 {
|
||||
searchPaths = l.configPaths
|
||||
}
|
||||
|
||||
// Check for config file in environment variable
|
||||
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
|
||||
searchPaths = append([]string{envPath}, searchPaths...)
|
||||
}
|
||||
|
||||
// Try each path
|
||||
for _, path := range searchPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return l.loadFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
// No config file found, not an error (use defaults)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// loadFile loads a specific configuration file
|
||||
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories (current dir or subdirs)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
// Check if unified config is enabled
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
// Use migrator to handle any version
|
||||
config, warnings, err := l.migrator.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, use proper logging
|
||||
fmt.Printf("Config Warning (%s): %s\n", path, warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Legacy path: load old config and convert
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
var oldConfig Config
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
if err := json.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
|
||||
}
|
||||
case ".yaml", ".yml":
|
||||
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
|
||||
}
|
||||
|
||||
return FromOldConfig(&oldConfig), nil
|
||||
}
|
||||
|
||||
// LoadFromEnv loads configuration from environment variables
|
||||
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
|
||||
// Provider configuration
|
||||
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
|
||||
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
|
||||
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
|
||||
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
|
||||
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
|
||||
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
|
||||
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
|
||||
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
|
||||
|
||||
// Session configuration
|
||||
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
|
||||
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
|
||||
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
|
||||
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
|
||||
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
|
||||
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
|
||||
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
|
||||
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
|
||||
|
||||
// Security configuration
|
||||
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
|
||||
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
|
||||
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
|
||||
|
||||
// Cache configuration
|
||||
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
|
||||
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
|
||||
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
|
||||
// MaxEntrySize is int64, skip for now
|
||||
|
||||
// Rate limiting
|
||||
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
|
||||
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
|
||||
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
|
||||
|
||||
// Logging
|
||||
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
|
||||
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
|
||||
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
|
||||
|
||||
// Redis configuration (already handled by its own LoadFromEnv)
|
||||
config.Redis.LoadFromEnv()
|
||||
|
||||
// Feature flags
|
||||
features.GetManager().LoadFromEnv()
|
||||
}
|
||||
|
||||
// Helper methods for environment variable loading
|
||||
|
||||
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitAndTrim(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeConfigs merges two configurations, with source overriding target
|
||||
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
|
||||
if source == nil {
|
||||
return target
|
||||
}
|
||||
if target == nil {
|
||||
return source
|
||||
}
|
||||
|
||||
// Use reflection for deep merge
|
||||
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
// mergeStructs recursively merges two structs
|
||||
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
|
||||
for i := 0; i < source.NumField(); i++ {
|
||||
sourceField := source.Field(i)
|
||||
targetField := target.Field(i)
|
||||
|
||||
// Skip if source field is zero value
|
||||
if isZeroValue(sourceField) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sourceField.Kind() {
|
||||
case reflect.Struct:
|
||||
// Recursively merge structs
|
||||
l.mergeStructs(targetField, sourceField)
|
||||
case reflect.Slice:
|
||||
// Replace slice if source has values
|
||||
if sourceField.Len() > 0 {
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
case reflect.Map:
|
||||
// Merge maps
|
||||
if !sourceField.IsNil() {
|
||||
if targetField.IsNil() {
|
||||
targetField.Set(reflect.MakeMap(sourceField.Type()))
|
||||
}
|
||||
for _, key := range sourceField.MapKeys() {
|
||||
targetField.SetMapIndex(key, sourceField.MapIndex(key))
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Replace value
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValue checks if a reflect.Value is a zero value
|
||||
func isZeroValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return v.IsNil()
|
||||
case reflect.Slice, reflect.Map:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Struct:
|
||||
// Check if all fields are zero
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !isZeroValue(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
zero := reflect.Zero(v.Type())
|
||||
return reflect.DeepEqual(v.Interface(), zero.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// SaveToFile saves the configuration to a file
|
||||
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(absPath))
|
||||
|
||||
var data []byte
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
data, err = json.MarshalIndent(config, "", " ")
|
||||
case ".yaml", ".yml":
|
||||
data, err = yaml.Marshal(config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported file extension: %s", ext)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist with secure permissions
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
// Write file with secure permissions (owner read/write only)
|
||||
if err := os.WriteFile(absPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,832 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConfigLoader tests the config loader functionality
|
||||
func TestConfigLoader(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
if loader == nil {
|
||||
t.Fatal("NewConfigLoader should not return nil")
|
||||
}
|
||||
|
||||
if loader.migrator == nil {
|
||||
t.Error("ConfigLoader should have a migrator")
|
||||
}
|
||||
|
||||
if loader.envPrefix != "TRAEFIKOIDC_" {
|
||||
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
|
||||
}
|
||||
|
||||
if len(loader.configPaths) == 0 {
|
||||
t.Error("ConfigLoader should have default config paths")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFromEnv tests loading configuration from environment variables
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
testEnvVars := map[string]string{
|
||||
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
|
||||
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
|
||||
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ENABLED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
|
||||
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
|
||||
"TRAEFIKOIDC_CACHE_ENABLED": "true",
|
||||
"TRAEFIKOIDC_CACHE_TYPE": "redis",
|
||||
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
|
||||
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
|
||||
}
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range testEnvVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{}
|
||||
loader.LoadFromEnv(config)
|
||||
|
||||
// Verify values were loaded
|
||||
if config.Provider.IssuerURL != "https://test.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client-id" {
|
||||
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
|
||||
}
|
||||
if config.Provider.ClientSecret != "test-secret" {
|
||||
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
|
||||
}
|
||||
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
|
||||
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
|
||||
}
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true")
|
||||
}
|
||||
if !config.Cache.Enabled {
|
||||
t.Error("Expected Cache to be enabled")
|
||||
}
|
||||
if config.Cache.Type != "redis" {
|
||||
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
|
||||
}
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("Expected RateLimit to be enabled")
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond != 100 {
|
||||
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveToFile tests saving configuration to files
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "32-character-encryption-key-12345",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "save as JSON",
|
||||
filename: "config.json",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YAML",
|
||||
filename: "config.yaml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YML",
|
||||
filename: "config.yml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported extension",
|
||||
filename: "config.txt",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/config.json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, tt.filename)
|
||||
err := loader.SaveToFile(config, filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file was created with correct permissions
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to stat saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check file permissions (should be 0600)
|
||||
mode := info.Mode().Perm()
|
||||
if mode != 0600 {
|
||||
t.Errorf("Expected file permissions 0600, got %o", mode)
|
||||
}
|
||||
|
||||
// Verify content can be read back
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify secrets are redacted
|
||||
content := string(data)
|
||||
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
|
||||
t.Error("Secrets should be redacted in saved file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFile tests loading configuration from files
|
||||
func TestLoadFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Test data - using old config format since unified config is not enabled by default
|
||||
jsonConfig := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
|
||||
yamlConfig := `
|
||||
providerurl: https://auth.example.com
|
||||
clientid: test-client
|
||||
clientsecret: secret
|
||||
sessionencryptionkey: 32-character-encryption-key-12345
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "load JSON config",
|
||||
filename: "config.json",
|
||||
content: jsonConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "load YAML config",
|
||||
filename: "config.yaml",
|
||||
content: yamlConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/passwd",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
filename: "does-not-exist.json",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var filePath string
|
||||
if tt.content != "" {
|
||||
filePath = filepath.Join(tmpDir, tt.filename)
|
||||
err := os.WriteFile(filePath, []byte(tt.content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
filePath = tt.filename
|
||||
}
|
||||
|
||||
config, err := loader.loadFile(filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verify loaded config
|
||||
if config == nil {
|
||||
t.Error("Expected config to be loaded")
|
||||
return
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================================
|
||||
// Tests for untested functions (0% coverage)
|
||||
// ====================================================================================
|
||||
|
||||
// TestConfigLoader_Load tests the full Load pipeline
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-load-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test config file
|
||||
configPath := filepath.Join(tmpDir, "traefik-oidc.json")
|
||||
configData := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to temp directory so loader can find the config
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
// Set some environment variables to test merging
|
||||
os.Setenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS", "true")
|
||||
defer os.Unsetenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.Load()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("Load() returned nil config")
|
||||
}
|
||||
|
||||
// Verify file was loaded
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL from file, got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Verify env vars were loaded
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS from env var to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_LoadFromFile tests the LoadFromFile function
|
||||
func TestConfigLoader_LoadFromFile(t *testing.T) {
|
||||
t.Run("NoConfigFile", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-nofile-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
// Should not error when no config file found
|
||||
if err != nil {
|
||||
t.Errorf("LoadFromFile() should not error when no file found: %v", err)
|
||||
}
|
||||
|
||||
// Should return nil config
|
||||
if config != nil {
|
||||
t.Error("LoadFromFile() should return nil config when no file found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadFromEnvPath", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-envpath-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "custom-config.json")
|
||||
configData := `{
|
||||
"providerURL": "https://custom.example.com",
|
||||
"clientID": "custom-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Set env variable pointing to config
|
||||
os.Setenv("TRAEFIKOIDC_CONFIG_FILE", configPath)
|
||||
defer os.Unsetenv("TRAEFIKOIDC_CONFIG_FILE")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://custom.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://custom.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadWithProvidedPaths", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-provided-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "specific.json")
|
||||
configData := `{
|
||||
"providerURL": "https://specific.example.com",
|
||||
"clientID": "specific-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile(configPath)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() with path failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://specific.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://specific.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAndTrim tests the splitAndTrim helper function
|
||||
func TestSplitAndTrim(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Simple comma-separated",
|
||||
input: "a,b,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "With spaces",
|
||||
input: "a, b , c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Empty strings filtered out",
|
||||
input: "a,,b, ,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Leading and trailing spaces",
|
||||
input: " a , b , c ",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "single",
|
||||
expected: []string{"single"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Only commas and spaces",
|
||||
input: " , , , ",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Complex real-world example",
|
||||
input: "openid, profile, email, groups",
|
||||
expected: []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitAndTrim(tt.input)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d items, got %d: %v", len(tt.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expected {
|
||||
if result[i] != expected {
|
||||
t.Errorf("At index %d: expected %q, got %q", i, expected, result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeConfigs tests the mergeConfigs function
|
||||
func TestConfigLoader_MergeConfigs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("MergeNilSource", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, nil)
|
||||
|
||||
if result != target {
|
||||
t.Error("mergeConfigs should return target when source is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeNilTarget", func(t *testing.T) {
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(nil, source)
|
||||
|
||||
if result != source {
|
||||
t.Error("mergeConfigs should return source when target is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSimpleFields", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "",
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
ClientID: "source-client",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if result.Provider.IssuerURL != "https://source.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be overridden, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be set, got %s", result.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSlices", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"openid", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Source slice should replace target slice
|
||||
if len(result.Provider.Scopes) != 2 {
|
||||
t.Errorf("Expected 2 scopes, got %d", len(result.Provider.Scopes))
|
||||
}
|
||||
|
||||
if result.Provider.Scopes[0] != "email" {
|
||||
t.Errorf("Expected first scope 'email', got %s", result.Provider.Scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeMaps", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Target-Header": "target-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Source-Header": "source-value",
|
||||
"X-Target-Header": "overridden-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if len(result.Middleware.CustomHeaders) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(result.Middleware.CustomHeaders))
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Target-Header"] != "overridden-value" {
|
||||
t.Errorf("Expected X-Target-Header to be overridden")
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Source-Header"] != "source-value" {
|
||||
t.Errorf("Expected X-Source-Header to be added")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeStructs tests the mergeStructs function indirectly
|
||||
func TestConfigLoader_MergeStructs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("NestedStructMerge", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "target-client",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "target-session",
|
||||
MaxAge: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "source-client",
|
||||
ClientSecret: "source-secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
MaxAge: 7200,
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Provider.IssuerURL should remain (zero value in source)
|
||||
if result.Provider.IssuerURL != "https://target.example.com" {
|
||||
t.Errorf("Expected IssuerURL to remain, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Provider.ClientID should be overridden
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be overridden, got %s", result.Provider.ClientID)
|
||||
}
|
||||
|
||||
// Provider.ClientSecret should be added
|
||||
if result.Provider.ClientSecret != "source-secret" {
|
||||
t.Errorf("Expected ClientSecret to be added, got %s", result.Provider.ClientSecret)
|
||||
}
|
||||
|
||||
// Session.Name should remain (zero value in source)
|
||||
if result.Session.Name != "target-session" {
|
||||
t.Errorf("Expected Session.Name to remain, got %s", result.Session.Name)
|
||||
}
|
||||
|
||||
// Session.MaxAge should be overridden
|
||||
if result.Session.MaxAge != 7200 {
|
||||
t.Errorf("Expected Session.MaxAge to be overridden, got %d", result.Session.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsZeroValue tests the isZeroValue helper function
|
||||
func TestIsZeroValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Zero string",
|
||||
value: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero string",
|
||||
value: "hello",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero int",
|
||||
value: 0,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero int",
|
||||
value: 42,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero bool",
|
||||
value: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero bool",
|
||||
value: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil pointer",
|
||||
value: (*string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-nil pointer",
|
||||
value: stringPtr("test"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil slice",
|
||||
value: ([]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
value: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty slice",
|
||||
value: []string{"a"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil map",
|
||||
value: (map[string]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty map",
|
||||
value: map[string]string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty map",
|
||||
value: map[string]string{"key": "value"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.value)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected isZeroValue to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsZeroValue_Struct tests isZeroValue with struct types
|
||||
func TestIsZeroValue_Struct(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Field1 string
|
||||
Field2 int
|
||||
}
|
||||
|
||||
t.Run("Zero struct", func(t *testing.T) {
|
||||
s := TestStruct{}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected zero struct to return true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field1 set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test"}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field2 set", func(t *testing.T) {
|
||||
s := TestStruct{Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Both fields set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test", Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function for pointer tests
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
@@ -1,407 +0,0 @@
|
||||
// Package config provides configuration migration from old to new format
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigVersion represents the version of a configuration format
|
||||
type ConfigVersion string
|
||||
|
||||
const (
|
||||
// VersionLegacy represents the original config format
|
||||
VersionLegacy ConfigVersion = "legacy"
|
||||
|
||||
// VersionUnified represents the new unified config format
|
||||
VersionUnified ConfigVersion = "unified"
|
||||
|
||||
// CurrentVersion is the current config version
|
||||
CurrentVersion ConfigVersion = VersionUnified
|
||||
)
|
||||
|
||||
// ConfigMigrator handles migration between config versions
|
||||
type ConfigMigrator struct {
|
||||
compatLayer *compat.CompatibilityLayer
|
||||
migrations map[ConfigVersion]MigrationFunc
|
||||
}
|
||||
|
||||
// MigrationFunc defines a function that migrates configuration
|
||||
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
|
||||
|
||||
// NewConfigMigrator creates a new configuration migrator
|
||||
func NewConfigMigrator() *ConfigMigrator {
|
||||
m := &ConfigMigrator{
|
||||
compatLayer: compat.GetLayer(),
|
||||
migrations: make(map[ConfigVersion]MigrationFunc),
|
||||
}
|
||||
|
||||
// Register migration functions
|
||||
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// DetectVersion detects the version of a configuration
|
||||
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
|
||||
var testMap map[string]interface{}
|
||||
|
||||
// Try JSON first
|
||||
if err := json.Unmarshal(data, &testMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &testMap); err != nil {
|
||||
return VersionLegacy // Default to legacy if can't parse
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unified config markers
|
||||
if _, hasProvider := testMap["provider"]; hasProvider {
|
||||
if _, hasSession := testMap["session"]; hasSession {
|
||||
return VersionUnified
|
||||
}
|
||||
}
|
||||
|
||||
// Check for legacy config markers
|
||||
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
// Migrate migrates configuration data to the current version
|
||||
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
|
||||
warnings := []string{}
|
||||
|
||||
// Detect version
|
||||
version := m.DetectVersion(data)
|
||||
|
||||
// If already current version, just unmarshal
|
||||
if version == CurrentVersion {
|
||||
var config UnifiedConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
|
||||
}
|
||||
}
|
||||
return &config, warnings, nil
|
||||
}
|
||||
|
||||
// Parse to generic map
|
||||
var configMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &configMap); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply migration
|
||||
migrationFunc, exists := m.migrations[version]
|
||||
if !exists {
|
||||
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
|
||||
}
|
||||
|
||||
config, err := migrationFunc(configMap)
|
||||
if err != nil {
|
||||
return nil, warnings, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Collect any deprecation warnings
|
||||
for key := range configMap {
|
||||
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
|
||||
warnings = append(warnings, warning)
|
||||
}
|
||||
}
|
||||
|
||||
return config, warnings, nil
|
||||
}
|
||||
|
||||
// migrateLegacyToUnified migrates legacy config to unified format
|
||||
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Use compatibility layer for field mapping
|
||||
migratedMap, warnings := m.compatLayer.MigrateMap(data)
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, these would be logged
|
||||
_ = warning
|
||||
}
|
||||
|
||||
// Map provider configuration
|
||||
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
|
||||
_ = mapToStruct(provider, &config.Provider)
|
||||
} else {
|
||||
// Direct field mapping for legacy format
|
||||
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
|
||||
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
|
||||
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
|
||||
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
|
||||
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
|
||||
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
|
||||
|
||||
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
|
||||
config.Provider.Scopes = scopes
|
||||
}
|
||||
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
|
||||
}
|
||||
|
||||
// Map session configuration
|
||||
if session, ok := getNestedMap(migratedMap, "Session"); ok {
|
||||
_ = mapToStruct(session, &config.Session)
|
||||
} else {
|
||||
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
|
||||
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
|
||||
}
|
||||
|
||||
// Map security configuration
|
||||
if security, ok := getNestedMap(migratedMap, "Security"); ok {
|
||||
_ = mapToStruct(security, &config.Security)
|
||||
} else {
|
||||
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
|
||||
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
|
||||
|
||||
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
|
||||
config.Security.AllowedUsers = users
|
||||
}
|
||||
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
|
||||
config.Security.AllowedUserDomains = domains
|
||||
}
|
||||
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
|
||||
config.Security.AllowedRolesAndGroups = roles
|
||||
}
|
||||
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
|
||||
config.Security.ExcludedURLs = excluded
|
||||
}
|
||||
|
||||
// Handle security headers
|
||||
if headers := data["securityHeaders"]; headers != nil {
|
||||
// Security headers might be in old format
|
||||
_ = mapToStruct(headers, &config.Security.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
// Map rate limiting
|
||||
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = rateLimit
|
||||
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
|
||||
}
|
||||
|
||||
// Map token configuration
|
||||
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
|
||||
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
|
||||
}
|
||||
|
||||
// Map logging
|
||||
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
|
||||
if config.Logging.Level == "" {
|
||||
config.Logging.Level = "info"
|
||||
}
|
||||
|
||||
// Map custom headers
|
||||
if headers := data["headers"]; headers != nil {
|
||||
if headerList, ok := headers.([]interface{}); ok {
|
||||
config.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, h := range headerList {
|
||||
if headerMap, ok := h.(map[string]interface{}); ok {
|
||||
name := getStringFromInterface(headerMap["name"])
|
||||
value := getStringFromInterface(headerMap["value"])
|
||||
if name != "" {
|
||||
config.Middleware.CustomHeaders[name] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store original data for reference
|
||||
config.Legacy = data
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// MigrateFile migrates a configuration file
|
||||
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
config, warnings, err := m.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
fmt.Printf("Migration Warning: %s\n", warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// AutoMigrate automatically migrates config based on feature flags
|
||||
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Feature not enabled, return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
migrator := NewConfigMigrator()
|
||||
|
||||
// Handle different input types
|
||||
switch v := data.(type) {
|
||||
case []byte:
|
||||
config, _, err := migrator.Migrate(v)
|
||||
return config, err
|
||||
case string:
|
||||
config, _, err := migrator.Migrate([]byte(v))
|
||||
return config, err
|
||||
case *Config:
|
||||
// Convert old config to unified
|
||||
return FromOldConfig(v), nil
|
||||
case *UnifiedConfig:
|
||||
// Already unified
|
||||
return v, nil
|
||||
case map[string]interface{}:
|
||||
// Convert map to JSON then migrate
|
||||
jsonData, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config, _, err := migrator.Migrate(jsonData)
|
||||
return config, err
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||
if val, exists := m[key]; exists {
|
||||
if mapped, ok := val.(map[string]interface{}); ok {
|
||||
return mapped, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getStringValue(m map[string]interface{}, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
return getStringFromInterface(val)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getStringFromInterface(val interface{}) string {
|
||||
if val == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func getBoolValue(m map[string]interface{}, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
// Try string conversion
|
||||
if s, ok := val.(string); ok {
|
||||
return strings.ToLower(s) == "true"
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getIntValue(m map[string]interface{}, keys ...string) int {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
// Try to parse
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
|
||||
// If parsing fails, return default
|
||||
return 0
|
||||
}
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getArrayValue(m map[string]interface{}, keys ...string) []string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if arr, ok := val.([]interface{}); ok {
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
result = append(result, getStringFromInterface(item))
|
||||
}
|
||||
return result
|
||||
}
|
||||
if strArr, ok := val.([]string); ok {
|
||||
return strArr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapToStruct(m interface{}, target interface{}) error {
|
||||
// Simple mapping using JSON as intermediate
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, target)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,297 +0,0 @@
|
||||
// Package config provides configuration structures for the Traefik OIDC plugin.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedisMode represents the Redis deployment mode
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
// RedisModeStandalone represents a single Redis instance
|
||||
RedisModeStandalone RedisMode = "standalone"
|
||||
|
||||
// RedisModeCluster represents Redis cluster mode
|
||||
RedisModeCluster RedisMode = "cluster"
|
||||
|
||||
// RedisModeSentinel represents Redis sentinel mode
|
||||
RedisModeSentinel RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
// RedisConfig holds Redis cache backend configuration
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis backend should be used
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// Mode specifies the Redis deployment mode
|
||||
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
|
||||
|
||||
// === Standalone Configuration ===
|
||||
// Addr is the Redis server address (host:port)
|
||||
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
|
||||
|
||||
// Password for Redis authentication
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the database number (0-15)
|
||||
DB int `json:"db,omitempty" yaml:"db,omitempty"`
|
||||
|
||||
// === Cluster Configuration ===
|
||||
// ClusterAddrs is the list of cluster node addresses
|
||||
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
|
||||
|
||||
// === Sentinel Configuration ===
|
||||
// MasterName is the name of the master instance
|
||||
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
|
||||
|
||||
// SentinelAddrs is the list of sentinel addresses
|
||||
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
|
||||
|
||||
// SentinelPassword is the password for sentinel authentication
|
||||
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
|
||||
|
||||
// === Connection Pool Settings ===
|
||||
// PoolSize is the maximum number of socket connections
|
||||
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections
|
||||
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up
|
||||
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
|
||||
|
||||
// === Timeouts ===
|
||||
// DialTimeout is the timeout for establishing new connections
|
||||
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
|
||||
|
||||
// ReadTimeout is the timeout for socket reads
|
||||
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
|
||||
|
||||
// WriteTimeout is the timeout for socket writes
|
||||
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
|
||||
|
||||
// PoolTimeout is the timeout for connection pool
|
||||
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
|
||||
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
|
||||
|
||||
// ConnMaxLifetime is the maximum lifetime of a connection
|
||||
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
|
||||
|
||||
// === Key Management ===
|
||||
// KeyPrefix is the prefix for all Redis keys
|
||||
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
|
||||
|
||||
// === TLS Configuration ===
|
||||
// TLSEnabled enables TLS for Redis connections
|
||||
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
|
||||
|
||||
// TLSInsecureSkipVerify skips TLS certificate verification
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
|
||||
|
||||
// === Resilience Settings ===
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis operations
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
|
||||
|
||||
// CircuitBreakerMaxFailures is the number of failures before opening circuit
|
||||
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
|
||||
|
||||
// CircuitBreakerTimeout is how long the circuit stays open
|
||||
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks
|
||||
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
|
||||
|
||||
// HealthCheckInterval is how often to check Redis health
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns default Redis configuration
|
||||
func DefaultRedisConfig() *RedisConfig {
|
||||
return &RedisConfig{
|
||||
Enabled: false,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 2,
|
||||
MaxRetries: 3,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
KeyPrefix: "traefikoidc:",
|
||||
TLSEnabled: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
EnableCircuitBreaker: true,
|
||||
CircuitBreakerMaxFailures: 5,
|
||||
CircuitBreakerTimeout: 30 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Redis configuration from environment variables
|
||||
func (c *RedisConfig) LoadFromEnv() {
|
||||
// Enable Redis if environment variable is set
|
||||
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
|
||||
c.Enabled = strings.ToLower(enabled) == "true"
|
||||
}
|
||||
|
||||
// Mode
|
||||
if mode := os.Getenv("REDIS_MODE"); mode != "" {
|
||||
c.Mode = RedisMode(strings.ToLower(mode))
|
||||
}
|
||||
|
||||
// Standalone configuration
|
||||
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
|
||||
c.Addr = addr
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
c.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster configuration
|
||||
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
|
||||
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
|
||||
for i := range c.ClusterAddrs {
|
||||
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel configuration
|
||||
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
|
||||
c.MasterName = masterName
|
||||
}
|
||||
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
|
||||
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
|
||||
for i := range c.SentinelAddrs {
|
||||
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
|
||||
}
|
||||
}
|
||||
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
|
||||
c.SentinelPassword = sentinelPassword
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if size, err := strconv.Atoi(poolSize); err == nil {
|
||||
c.PoolSize = size
|
||||
}
|
||||
}
|
||||
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
|
||||
if conns, err := strconv.Atoi(minIdleConns); err == nil {
|
||||
c.MinIdleConns = conns
|
||||
}
|
||||
}
|
||||
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
|
||||
if retries, err := strconv.Atoi(maxRetries); err == nil {
|
||||
c.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
// Timeouts
|
||||
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
|
||||
c.DialTimeout = timeout
|
||||
}
|
||||
}
|
||||
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(readTimeout); err == nil {
|
||||
c.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
|
||||
c.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Key prefix
|
||||
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
|
||||
c.KeyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
|
||||
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
|
||||
}
|
||||
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
|
||||
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
|
||||
}
|
||||
|
||||
// Resilience settings
|
||||
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
|
||||
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
|
||||
}
|
||||
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
|
||||
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
|
||||
c.CircuitBreakerMaxFailures = failures
|
||||
}
|
||||
}
|
||||
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
|
||||
c.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
|
||||
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
|
||||
}
|
||||
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
|
||||
if interval, err := time.ParseDuration(hcInterval); err == nil {
|
||||
c.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *RedisConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case RedisModeStandalone:
|
||||
if c.Addr == "" {
|
||||
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
|
||||
}
|
||||
case RedisModeCluster:
|
||||
if len(c.ClusterAddrs) == 0 {
|
||||
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
|
||||
}
|
||||
case RedisModeSentinel:
|
||||
if c.MasterName == "" {
|
||||
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
|
||||
}
|
||||
if len(c.SentinelAddrs) == 0 {
|
||||
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
|
||||
}
|
||||
default:
|
||||
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ConfigError) Error() string {
|
||||
return "redis config error: " + e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -1,511 +0,0 @@
|
||||
// Package config provides configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
minEncryptionKeyLength = 16
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
//lint:ignore U1000 May be referenced for default exclusion patterns
|
||||
var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon.ico": {},
|
||||
"/robots.txt": {},
|
||||
"/health": {},
|
||||
"/.well-known/": {},
|
||||
"/metrics": {},
|
||||
"/ping": {},
|
||||
"/api/": {},
|
||||
"/static/": {},
|
||||
"/assets/": {},
|
||||
"/js/": {},
|
||||
"/css/": {},
|
||||
"/images/": {},
|
||||
"/fonts/": {},
|
||||
}
|
||||
|
||||
// Settings manages configuration and initialization for the OIDC middleware
|
||||
type Settings struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// Config represents the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerUrl"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
LogoutURL string `json:"logoutUrl"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHttps"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
Scopes []string `json:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedUrls"`
|
||||
EnablePKCE bool `json:"enablePkce"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
Headers []HeaderConfig `json:"headers"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
|
||||
// Dynamic Client Registration (RFC 7591) configuration
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrationConfig struct {
|
||||
// Enabled enables automatic client registration with the OIDC provider
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// InitialAccessToken is an optional bearer token for protected registration endpoints
|
||||
// Some providers require this token to authorize new client registrations
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
|
||||
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
|
||||
// If empty, uses the registration_endpoint from .well-known/openid-configuration
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
|
||||
// ClientMetadata contains the client metadata to register
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
|
||||
// PersistCredentials determines whether to save registered credentials to a file
|
||||
// This allows reusing the same client_id/client_secret across restarts
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
|
||||
// CredentialsFile is the path to store/load registered client credentials
|
||||
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
type ClientRegistrationMetadata struct {
|
||||
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
|
||||
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
|
||||
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
|
||||
// ApplicationType is either "web" (default) or "native"
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
|
||||
// Contacts is an array of email addresses for responsible parties
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
|
||||
// ClientName is a human-readable name for the client
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
|
||||
// LogoURI is a URL pointing to a logo for the client
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
|
||||
// ClientURI is a URL of the home page of the client
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
|
||||
// PolicyURI is a URL pointing to the client's privacy policy
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
|
||||
// TOSURI is a URL pointing to the client's terms of service
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
|
||||
// JWKSURI is a URL for the client's JSON Web Key Set
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
|
||||
// SubjectType is "pairwise" or "public" (provider-specific)
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
|
||||
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
|
||||
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
|
||||
// DefaultMaxAge is the default maximum authentication age in seconds
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
|
||||
// RequireAuthTime specifies whether auth_time claim is required in ID token
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
|
||||
// DefaultACRValues specifies default ACR values
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
|
||||
// Scope is a space-separated list of scopes (alternative to config.Scopes)
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// HeaderConfig represents header template configuration
|
||||
type HeaderConfig struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
type SecurityHeadersConfig struct {
|
||||
// Enable security headers (default: true)
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Security profile: "default", "strict", "development", "api", or "custom"
|
||||
Profile string `json:"profile"`
|
||||
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
|
||||
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
|
||||
// Content type options (default: "nosniff")
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
|
||||
// XSS protection (default: "1; mode=block")
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
CORSMaxAge int `json:"corsMaxAge"` // seconds
|
||||
|
||||
// Custom headers (in addition to standard security headers)
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
}
|
||||
|
||||
// NewSettings creates a new Settings instance
|
||||
func NewSettings(logger Logger) *Settings {
|
||||
return &Settings{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConfig creates a default configuration
|
||||
func CreateConfig() *Config {
|
||||
return &Config{
|
||||
LogLevel: "INFO",
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
RateLimit: 10,
|
||||
RefreshGracePeriodSeconds: 60,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Headers: []HeaderConfig{},
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultSecurityConfig creates a default security headers configuration
|
||||
func createDefaultSecurityConfig() *SecurityHeadersConfig {
|
||||
return &SecurityHeadersConfig{
|
||||
Enabled: true,
|
||||
Profile: "default",
|
||||
|
||||
// Default security headers
|
||||
StrictTransportSecurity: true,
|
||||
StrictTransportSecurityMaxAge: 31536000, // 1 year
|
||||
StrictTransportSecuritySubdomains: true,
|
||||
StrictTransportSecurityPreload: true,
|
||||
|
||||
FrameOptions: "DENY",
|
||||
ContentTypeOptions: "nosniff",
|
||||
XSSProtection: "1; mode=block",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
|
||||
// CORS disabled by default
|
||||
CORSEnabled: false,
|
||||
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
|
||||
CORSAllowCredentials: false,
|
||||
CORSMaxAge: 86400, // 24 hours
|
||||
|
||||
// Security features
|
||||
DisableServerHeader: true,
|
||||
DisablePoweredByHeader: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
|
||||
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
|
||||
if c == nil || !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the internal security config structure
|
||||
config := map[string]interface{}{
|
||||
"DevelopmentMode": false,
|
||||
}
|
||||
|
||||
// Apply profile-based defaults
|
||||
switch strings.ToLower(c.Profile) {
|
||||
case "strict":
|
||||
applyStrictProfile(config)
|
||||
case "development":
|
||||
applyDevelopmentProfile(config)
|
||||
case "api":
|
||||
applyAPIProfile(config)
|
||||
case "custom":
|
||||
// No defaults, use only what's explicitly configured
|
||||
default: // "default"
|
||||
applyDefaultProfile(config)
|
||||
}
|
||||
|
||||
// Override with explicit configuration
|
||||
if c.ContentSecurityPolicy != "" {
|
||||
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
|
||||
}
|
||||
|
||||
// HSTS configuration
|
||||
if c.StrictTransportSecurity {
|
||||
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
|
||||
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
|
||||
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
|
||||
}
|
||||
|
||||
// Frame options
|
||||
if c.FrameOptions != "" {
|
||||
config["FrameOptions"] = c.FrameOptions
|
||||
}
|
||||
|
||||
// Content type and XSS protection
|
||||
if c.ContentTypeOptions != "" {
|
||||
config["ContentTypeOptions"] = c.ContentTypeOptions
|
||||
}
|
||||
if c.XSSProtection != "" {
|
||||
config["XSSProtection"] = c.XSSProtection
|
||||
}
|
||||
|
||||
// Referrer and permissions policies
|
||||
if c.ReferrerPolicy != "" {
|
||||
config["ReferrerPolicy"] = c.ReferrerPolicy
|
||||
}
|
||||
if c.PermissionsPolicy != "" {
|
||||
config["PermissionsPolicy"] = c.PermissionsPolicy
|
||||
}
|
||||
|
||||
// Cross-origin policies
|
||||
if c.CrossOriginEmbedderPolicy != "" {
|
||||
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
|
||||
}
|
||||
if c.CrossOriginOpenerPolicy != "" {
|
||||
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
|
||||
}
|
||||
if c.CrossOriginResourcePolicy != "" {
|
||||
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
// CORS configuration
|
||||
config["CORSEnabled"] = c.CORSEnabled
|
||||
if len(c.CORSAllowedOrigins) > 0 {
|
||||
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
|
||||
}
|
||||
if len(c.CORSAllowedMethods) > 0 {
|
||||
config["CORSAllowedMethods"] = c.CORSAllowedMethods
|
||||
}
|
||||
if len(c.CORSAllowedHeaders) > 0 {
|
||||
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
|
||||
}
|
||||
config["CORSAllowCredentials"] = c.CORSAllowCredentials
|
||||
if c.CORSMaxAge > 0 {
|
||||
config["CORSMaxAge"] = c.CORSMaxAge
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
if len(c.CustomHeaders) > 0 {
|
||||
config["CustomHeaders"] = c.CustomHeaders
|
||||
}
|
||||
|
||||
// Security features
|
||||
config["DisableServerHeader"] = c.DisableServerHeader
|
||||
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// applyDefaultProfile applies default security settings
|
||||
func applyDefaultProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-origin"
|
||||
}
|
||||
|
||||
// applyStrictProfile applies strict security settings
|
||||
func applyStrictProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-site"
|
||||
}
|
||||
|
||||
// applyDevelopmentProfile applies development-friendly settings
|
||||
func applyDevelopmentProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
|
||||
config["FrameOptions"] = "SAMEORIGIN"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginOpenerPolicy"] = "unsafe-none"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
config["DevelopmentMode"] = true
|
||||
}
|
||||
|
||||
// applyAPIProfile applies API-friendly settings
|
||||
func applyAPIProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This would need to import the internal security package
|
||||
// For now, return a basic implementation
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
|
||||
// Apply basic security headers based on configuration
|
||||
if c.SecurityHeaders.FrameOptions != "" {
|
||||
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
|
||||
}
|
||||
if c.SecurityHeaders.ContentTypeOptions != "" {
|
||||
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
|
||||
}
|
||||
if c.SecurityHeaders.XSSProtection != "" {
|
||||
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
|
||||
}
|
||||
if c.SecurityHeaders.ReferrerPolicy != "" {
|
||||
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
|
||||
}
|
||||
if c.SecurityHeaders.ContentSecurityPolicy != "" {
|
||||
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// HSTS for HTTPS
|
||||
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
|
||||
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
|
||||
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
|
||||
hstsValue += "; includeSubDomains"
|
||||
}
|
||||
if c.SecurityHeaders.StrictTransportSecurityPreload {
|
||||
hstsValue += "; preload"
|
||||
}
|
||||
headers.Set("Strict-Transport-Security", hstsValue)
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
if c.SecurityHeaders.CORSEnabled {
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
|
||||
}
|
||||
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
|
||||
}
|
||||
if c.SecurityHeaders.CORSAllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if c.SecurityHeaders.CORSMaxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
for name, value := range c.SecurityHeaders.CustomHeaders {
|
||||
headers.Set(name, value)
|
||||
}
|
||||
|
||||
// Remove server headers
|
||||
if c.SecurityHeaders.DisableServerHeader {
|
||||
headers.Del("Server")
|
||||
}
|
||||
if c.SecurityHeaders.DisablePoweredByHeader {
|
||||
headers.Del("X-Powered-By")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if an origin is in the allowed list
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
return true
|
||||
}
|
||||
// Simple wildcard matching for subdomains
|
||||
if strings.Contains(allowed, "*") {
|
||||
if strings.HasPrefix(allowed, "https://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "https://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(allowed, "http://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "http://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,287 +0,0 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnifiedConfig is the master configuration structure consolidating all config aspects
|
||||
// This replaces 45 duplicate config structs across the codebase
|
||||
type UnifiedConfig struct {
|
||||
// Core Configuration
|
||||
Provider ProviderConfig `json:"provider" yaml:"provider"`
|
||||
Session SessionConfig `json:"session" yaml:"session"`
|
||||
Token TokenConfig `json:"token" yaml:"token"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Security SecurityConfig `json:"security" yaml:"security"`
|
||||
|
||||
// Middleware Configuration
|
||||
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
|
||||
Cache CacheConfig `json:"cache" yaml:"cache"`
|
||||
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
|
||||
|
||||
// Operational Configuration
|
||||
Logging LoggingConfig `json:"logging" yaml:"logging"`
|
||||
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
|
||||
Health HealthConfig `json:"health" yaml:"health"`
|
||||
|
||||
// Advanced Configuration
|
||||
Transport TransportConfig `json:"transport" yaml:"transport"`
|
||||
Pool PoolConfig `json:"pool" yaml:"pool"`
|
||||
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
|
||||
|
||||
// Compatibility field for migration
|
||||
Legacy map[string]interface{} `json:"-" yaml:"-"`
|
||||
}
|
||||
|
||||
// ProviderConfig contains OIDC provider settings
|
||||
type ProviderConfig struct {
|
||||
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
|
||||
ClientID string `json:"clientID" yaml:"clientID"`
|
||||
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
|
||||
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
|
||||
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
|
||||
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
|
||||
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
|
||||
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
|
||||
Discovery bool `json:"discovery" yaml:"discovery"`
|
||||
|
||||
// Provider-specific endpoints
|
||||
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
|
||||
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
|
||||
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
|
||||
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
|
||||
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
|
||||
}
|
||||
|
||||
// SessionConfig contains session management settings
|
||||
type SessionConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
MaxAge int `json:"maxAge" yaml:"maxAge"`
|
||||
Secret string `json:"secret" yaml:"secret"`
|
||||
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
|
||||
SigningKey string `json:"signingKey" yaml:"signingKey"`
|
||||
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
|
||||
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
|
||||
|
||||
// Cookie settings
|
||||
Domain string `json:"domain" yaml:"domain"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
Secure bool `json:"secure" yaml:"secure"`
|
||||
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
|
||||
SameSite string `json:"sameSite" yaml:"sameSite"`
|
||||
CookiePrefix string `json:"cookiePrefix" yaml:"cookiePrefix"` // Prefix for cookie names (e.g., "_oidc_myapp_")
|
||||
|
||||
// Storage settings
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
}
|
||||
|
||||
// TokenConfig contains token handling settings
|
||||
type TokenConfig struct {
|
||||
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
|
||||
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
|
||||
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
|
||||
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
|
||||
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
|
||||
|
||||
// Token caching
|
||||
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
|
||||
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
|
||||
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
|
||||
|
||||
// Token validation
|
||||
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
|
||||
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
|
||||
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
|
||||
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
|
||||
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
|
||||
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
|
||||
}
|
||||
|
||||
// SecurityConfig contains security-related settings
|
||||
type SecurityConfig struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
|
||||
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
|
||||
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
|
||||
|
||||
// CSRF protection
|
||||
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
|
||||
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
|
||||
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
|
||||
|
||||
// Additional security
|
||||
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
|
||||
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
|
||||
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
|
||||
}
|
||||
|
||||
// MiddlewareConfig contains middleware-specific settings
|
||||
type MiddlewareConfig struct {
|
||||
Priority int `json:"priority" yaml:"priority"`
|
||||
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
|
||||
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
|
||||
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
|
||||
|
||||
// Request handling
|
||||
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
|
||||
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
|
||||
// Response handling
|
||||
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
|
||||
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache configuration
|
||||
type CacheConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
|
||||
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
|
||||
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
|
||||
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
|
||||
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
|
||||
|
||||
// Memory cache settings
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
|
||||
// Distributed cache settings
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Compression bool `json:"compression" yaml:"compression"`
|
||||
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
|
||||
Burst int `json:"burst" yaml:"burst"`
|
||||
|
||||
// Rate limit storage
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
|
||||
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
|
||||
|
||||
// Rate limit keys
|
||||
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
|
||||
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
|
||||
|
||||
// Whitelisting
|
||||
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
|
||||
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
|
||||
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
|
||||
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
|
||||
FilePath string `json:"filePath" yaml:"filePath"`
|
||||
|
||||
// Log filtering
|
||||
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
|
||||
MaskFields []string `json:"maskFields" yaml:"maskFields"`
|
||||
|
||||
// Performance
|
||||
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
|
||||
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
|
||||
|
||||
// Audit logging
|
||||
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
|
||||
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
|
||||
}
|
||||
|
||||
// MetricsConfig contains metrics collection configuration
|
||||
type MetricsConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
|
||||
Endpoint string `json:"endpoint" yaml:"endpoint"`
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Subsystem string `json:"subsystem" yaml:"subsystem"`
|
||||
|
||||
// Collection settings
|
||||
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
|
||||
Histograms bool `json:"histograms" yaml:"histograms"`
|
||||
|
||||
// Custom labels
|
||||
Labels map[string]string `json:"labels" yaml:"labels"`
|
||||
}
|
||||
|
||||
// HealthConfig contains health check configuration
|
||||
type HealthConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
|
||||
// Checks to perform
|
||||
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
|
||||
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
|
||||
CheckCache bool `json:"checkCache" yaml:"checkCache"`
|
||||
|
||||
// Thresholds
|
||||
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
|
||||
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
|
||||
}
|
||||
|
||||
// TransportConfig contains HTTP transport configuration
|
||||
type TransportConfig struct {
|
||||
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
|
||||
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
|
||||
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
|
||||
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
|
||||
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
|
||||
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
|
||||
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
|
||||
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
|
||||
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
|
||||
|
||||
// TLS configuration
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
|
||||
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
|
||||
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
|
||||
|
||||
// Proxy settings
|
||||
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
|
||||
NoProxy []string `json:"noProxy" yaml:"noProxy"`
|
||||
}
|
||||
|
||||
// PoolConfig contains connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Size int `json:"size" yaml:"size"`
|
||||
MinSize int `json:"minSize" yaml:"minSize"`
|
||||
MaxSize int `json:"maxSize" yaml:"maxSize"`
|
||||
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
|
||||
|
||||
// Health checking
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// CircuitConfig contains circuit breaker configuration
|
||||
type CircuitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
|
||||
Interval time.Duration `json:"interval" yaml:"interval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
|
||||
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
|
||||
|
||||
// Circuit states
|
||||
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
|
||||
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
|
||||
|
||||
// Monitoring
|
||||
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
|
||||
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
|
||||
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Session.Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("Session.EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("Session.SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Redis.Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("Redis.SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
|
||||
t.Error("IssuerURL should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
|
||||
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
yamlBytes, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "secret: '[REDACTED]'") {
|
||||
t.Error("Session.Secret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
|
||||
t.Error("Session.EncryptionKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
|
||||
t.Error("Session.SigningKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "password: '[REDACTED]'") {
|
||||
t.Error("Redis.Password should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
|
||||
t.Error("Redis.SentinelPassword should be redacted in YAML output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
|
||||
t.Error("IssuerURL should be preserved in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderConfigMarshalling tests individual struct marshalling
|
||||
func TestProviderConfigMarshalling(t *testing.T) {
|
||||
provider := ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
|
||||
// Test YAML marshalling
|
||||
yamlBytes, err := yaml.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionConfigMarshalling tests session config marshalling
|
||||
func TestSessionConfigMarshalling(t *testing.T) {
|
||||
session := SessionConfig{
|
||||
Name: "session-cookie",
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
Domain: "example.com",
|
||||
Secure: true,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"name":"session-cookie"`) {
|
||||
t.Error("Name should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"domain":"example.com"`) {
|
||||
t.Error("Domain should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisConfigMarshalling tests Redis config marshalling
|
||||
func TestRedisConfigMarshalling(t *testing.T) {
|
||||
redis := RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
Addr: "localhost:6379",
|
||||
DB: 1,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(redis)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"addr":"localhost:6379"`) {
|
||||
t.Error("Addr should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"db":1`) {
|
||||
t.Error("DB should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
|
||||
func TestEmptySecretsNotRedacted(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "", // Empty secret
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "", // Empty secret
|
||||
EncryptionKey: "", // Empty secret
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "", // Empty secret
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify empty secrets are not shown as redacted
|
||||
if contains(jsonStr, "[REDACTED]") {
|
||||
t.Error("Empty secrets should not be shown as [REDACTED]")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
@@ -1,652 +0,0 @@
|
||||
// Package config provides validation for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationError represents a configuration validation error
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ValidationError) Error() string {
|
||||
if e.Value != nil {
|
||||
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
|
||||
}
|
||||
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidationErrors represents multiple validation errors
|
||||
type ValidationErrors []ValidationError
|
||||
|
||||
// Error implements the error interface
|
||||
func (e ValidationErrors) Error() string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var messages []string
|
||||
for _, err := range e {
|
||||
messages = append(messages, err.Error())
|
||||
}
|
||||
return strings.Join(messages, "; ")
|
||||
}
|
||||
|
||||
// Validate performs comprehensive validation on the unified configuration
|
||||
func (c *UnifiedConfig) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate Provider configuration
|
||||
if err := c.validateProvider(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Session configuration
|
||||
if err := c.validateSession(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Token configuration
|
||||
if err := c.validateToken(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Redis configuration (uses existing validation)
|
||||
if err := c.Redis.Validate(); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Redis",
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate Security configuration
|
||||
if err := c.validateSecurity(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Middleware configuration
|
||||
if err := c.validateMiddleware(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Cache configuration
|
||||
if err := c.validateCache(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate RateLimit configuration
|
||||
if err := c.validateRateLimit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Logging configuration
|
||||
if err := c.validateLogging(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Metrics configuration
|
||||
if err := c.validateMetrics(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Transport configuration
|
||||
if err := c.validateTransport(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Circuit configuration
|
||||
if err := c.validateCircuit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateProvider validates provider configuration
|
||||
func (c *UnifiedConfig) validateProvider() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// IssuerURL is required and must be a valid URL
|
||||
if c.Provider.IssuerURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "issuer URL is required",
|
||||
})
|
||||
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "invalid issuer URL",
|
||||
Value: c.Provider.IssuerURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ClientID is required
|
||||
if c.Provider.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
// ClientSecret is required (except for public clients with PKCE)
|
||||
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE for public clients)",
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectURL must be valid if provided
|
||||
if c.Provider.RedirectURL != "" {
|
||||
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.RedirectURL",
|
||||
Message: "invalid redirect URL",
|
||||
Value: c.Provider.RedirectURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes must include 'openid' for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range c.Provider.Scopes {
|
||||
if scope == "openid" {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID && !c.Provider.OverrideScopes {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.Scopes",
|
||||
Message: "scopes must include 'openid' for OIDC",
|
||||
Value: c.Provider.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// JWK cache period must be positive
|
||||
if c.Provider.JWKCachePeriod < 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.JWKCachePeriod",
|
||||
Message: "JWK cache period must be positive",
|
||||
Value: c.Provider.JWKCachePeriod,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSession validates session configuration
|
||||
func (c *UnifiedConfig) validateSession() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Session name must not be empty
|
||||
if c.Session.Name == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.Name",
|
||||
Message: "session name is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Session secret or encryption key is required
|
||||
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session",
|
||||
Message: "either session secret or encryption key is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Encryption key must be at least 32 bytes for security
|
||||
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "encryption key must be at least 32 characters for proper security",
|
||||
Value: len(c.Session.EncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
// ChunkSize must be reasonable (between 1KB and 10KB)
|
||||
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.ChunkSize",
|
||||
Message: "chunk size must be between 1000 and 10000 bytes",
|
||||
Value: c.Session.ChunkSize,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxChunks must be reasonable (between 1 and 100)
|
||||
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.MaxChunks",
|
||||
Message: "max chunks must be between 1 and 100",
|
||||
Value: c.Session.MaxChunks,
|
||||
})
|
||||
}
|
||||
|
||||
// SameSite must be valid
|
||||
validSameSite := map[string]bool{
|
||||
"": true,
|
||||
"Lax": true,
|
||||
"Strict": true,
|
||||
"None": true,
|
||||
}
|
||||
if !validSameSite[c.Session.SameSite] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.SameSite",
|
||||
Message: "invalid SameSite value (must be Lax, Strict, or None)",
|
||||
Value: c.Session.SameSite,
|
||||
})
|
||||
}
|
||||
|
||||
// StorageType must be valid
|
||||
validStorage := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"cookie": true,
|
||||
}
|
||||
if !validStorage[c.Session.StorageType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.StorageType",
|
||||
Message: "invalid storage type (must be memory, redis, or cookie)",
|
||||
Value: c.Session.StorageType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateToken validates token configuration
|
||||
func (c *UnifiedConfig) validateToken() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Token TTLs must be positive
|
||||
if c.Token.AccessTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.AccessTokenTTL",
|
||||
Message: "access token TTL must be positive",
|
||||
Value: c.Token.AccessTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Token.RefreshTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.RefreshTokenTTL",
|
||||
Message: "refresh token TTL must be positive",
|
||||
Value: c.Token.RefreshTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
// Validation mode must be valid
|
||||
validModes := map[string]bool{
|
||||
"jwt": true,
|
||||
"introspect": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validModes[c.Token.ValidationMode] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ValidationMode",
|
||||
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
|
||||
Value: c.Token.ValidationMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect URL required for introspect or hybrid mode
|
||||
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.IntrospectURL",
|
||||
Message: "introspect URL is required for introspect or hybrid validation mode",
|
||||
})
|
||||
}
|
||||
|
||||
// Clock skew must be reasonable (0 to 10 minutes)
|
||||
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ClockSkew",
|
||||
Message: "clock skew must be between 0 and 10 minutes",
|
||||
Value: c.Token.ClockSkew,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSecurity validates security configuration
|
||||
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate allowed user domains are valid domains
|
||||
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
|
||||
for _, domain := range c.Security.AllowedUserDomains {
|
||||
if !domainRegex.MatchString(domain) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.AllowedUserDomains",
|
||||
Message: "invalid domain format",
|
||||
Value: domain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Max login attempts must be reasonable
|
||||
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.MaxLoginAttempts",
|
||||
Message: "max login attempts must be between 0 and 100",
|
||||
Value: c.Security.MaxLoginAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
// Lockout duration must be reasonable
|
||||
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.LockoutDuration",
|
||||
Message: "lockout duration must be between 0 and 24 hours",
|
||||
Value: c.Security.LockoutDuration,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMiddleware validates middleware configuration
|
||||
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max request size must be reasonable (1KB to 100MB)
|
||||
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.MaxRequestSize",
|
||||
Message: "max request size must be between 1KB and 100MB",
|
||||
Value: c.Middleware.MaxRequestSize,
|
||||
})
|
||||
}
|
||||
|
||||
// Request timeout must be reasonable
|
||||
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.RequestTimeout",
|
||||
Message: "request timeout must be between 1 second and 5 minutes",
|
||||
Value: c.Middleware.RequestTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCache validates cache configuration
|
||||
func (c *UnifiedConfig) validateCache() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Cache.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Cache type must be valid
|
||||
validTypes := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validTypes[c.Cache.Type] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.Type",
|
||||
Message: "invalid cache type (must be memory, redis, or hybrid)",
|
||||
Value: c.Cache.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// Max entries must be reasonable
|
||||
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.MaxEntries",
|
||||
Message: "max entries must be between 10 and 1000000",
|
||||
Value: c.Cache.MaxEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// Eviction policy must be valid
|
||||
validEviction := map[string]bool{
|
||||
"lru": true,
|
||||
"lfu": true,
|
||||
"fifo": true,
|
||||
}
|
||||
if !validEviction[c.Cache.EvictionPolicy] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.EvictionPolicy",
|
||||
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
|
||||
Value: c.Cache.EvictionPolicy,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateRateLimit validates rate limiting configuration
|
||||
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.RateLimit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Requests per second must be reasonable
|
||||
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.RequestsPerSecond",
|
||||
Message: "requests per second must be between 1 and 10000",
|
||||
Value: c.RateLimit.RequestsPerSecond,
|
||||
})
|
||||
}
|
||||
|
||||
// Burst must be at least as large as requests per second
|
||||
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.Burst",
|
||||
Message: "burst must be at least as large as requests per second",
|
||||
Value: c.RateLimit.Burst,
|
||||
})
|
||||
}
|
||||
|
||||
// Key type must be valid
|
||||
validKeyTypes := map[string]bool{
|
||||
"ip": true,
|
||||
"user": true,
|
||||
"token": true,
|
||||
"custom": true,
|
||||
}
|
||||
if !validKeyTypes[c.RateLimit.KeyType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.KeyType",
|
||||
Message: "invalid key type (must be ip, user, token, or custom)",
|
||||
Value: c.RateLimit.KeyType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateLogging validates logging configuration
|
||||
func (c *UnifiedConfig) validateLogging() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Log level must be valid
|
||||
validLevels := map[string]bool{
|
||||
"debug": true,
|
||||
"info": true,
|
||||
"warn": true,
|
||||
"error": true,
|
||||
}
|
||||
if !validLevels[c.Logging.Level] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Level",
|
||||
Message: "invalid log level (must be debug, info, warn, or error)",
|
||||
Value: c.Logging.Level,
|
||||
})
|
||||
}
|
||||
|
||||
// Format must be valid
|
||||
validFormats := map[string]bool{
|
||||
"json": true,
|
||||
"text": true,
|
||||
"structured": true,
|
||||
}
|
||||
if !validFormats[c.Logging.Format] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Format",
|
||||
Message: "invalid log format (must be json, text, or structured)",
|
||||
Value: c.Logging.Format,
|
||||
})
|
||||
}
|
||||
|
||||
// Output must be valid
|
||||
validOutputs := map[string]bool{
|
||||
"stdout": true,
|
||||
"stderr": true,
|
||||
"file": true,
|
||||
}
|
||||
if !validOutputs[c.Logging.Output] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Output",
|
||||
Message: "invalid log output (must be stdout, stderr, or file)",
|
||||
Value: c.Logging.Output,
|
||||
})
|
||||
}
|
||||
|
||||
// File path required if output is file
|
||||
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.FilePath",
|
||||
Message: "file path is required when output is 'file'",
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMetrics validates metrics configuration
|
||||
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Metrics.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Provider must be valid
|
||||
validProviders := map[string]bool{
|
||||
"prometheus": true,
|
||||
"statsd": true,
|
||||
"otlp": true,
|
||||
}
|
||||
if !validProviders[c.Metrics.Provider] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Provider",
|
||||
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
|
||||
Value: c.Metrics.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
// Endpoint required for some providers
|
||||
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Endpoint",
|
||||
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateTransport validates transport configuration
|
||||
func (c *UnifiedConfig) validateTransport() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max connections must be reasonable
|
||||
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.MaxIdleConns",
|
||||
Message: "max idle connections must be between 0 and 10000",
|
||||
Value: c.Transport.MaxIdleConns,
|
||||
})
|
||||
}
|
||||
|
||||
// TLS min version must be valid
|
||||
validTLSVersions := map[string]bool{
|
||||
"TLS1.0": true,
|
||||
"TLS1.1": true,
|
||||
"TLS1.2": true,
|
||||
"TLS1.3": true,
|
||||
}
|
||||
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.TLSMinVersion",
|
||||
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
|
||||
Value: c.Transport.TLSMinVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// Proxy URL must be valid if provided
|
||||
if c.Transport.ProxyURL != "" {
|
||||
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.ProxyURL",
|
||||
Message: "invalid proxy URL",
|
||||
Value: c.Transport.ProxyURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCircuit validates circuit breaker configuration
|
||||
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Circuit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Consecutive failures must be reasonable
|
||||
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.ConsecutiveFailures",
|
||||
Message: "consecutive failures must be between 1 and 100",
|
||||
Value: c.Circuit.ConsecutiveFailures,
|
||||
})
|
||||
}
|
||||
|
||||
// Failure ratio must be between 0 and 1
|
||||
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.FailureRatio",
|
||||
Message: "failure ratio must be between 0 and 1",
|
||||
Value: c.Circuit.FailureRatio,
|
||||
})
|
||||
}
|
||||
|
||||
// OnOpen action must be valid
|
||||
validActions := map[string]bool{
|
||||
"reject": true,
|
||||
"fallback": true,
|
||||
"passthrough": true,
|
||||
}
|
||||
if !validActions[c.Circuit.OnOpen] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.OnOpen",
|
||||
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
|
||||
Value: c.Circuit.OnOpen,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
@@ -1,588 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
|
||||
func TestValidateUnifiedConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *UnifiedConfig
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "valid config with minimum requirements",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "oidc_session",
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 5,
|
||||
StorageType: "cookie",
|
||||
},
|
||||
Token: TokenConfig{
|
||||
AccessTokenTTL: time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
ValidationMode: "jwt",
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
MaxRequestSize: 10 * 1024 * 1024,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.IssuerURL",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.ClientID",
|
||||
},
|
||||
{
|
||||
name: "encryption key too short",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "too-short",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.EncryptionKey",
|
||||
},
|
||||
{
|
||||
name: "invalid chunk size",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 500, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.ChunkSize",
|
||||
},
|
||||
{
|
||||
name: "invalid max chunks",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 0, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.MaxChunks",
|
||||
},
|
||||
{
|
||||
name: "invalid TLS min version",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Transport: TransportConfig{
|
||||
TLSMinVersion: "1.0", // Too old
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Transport.TLSMinVersion",
|
||||
},
|
||||
{
|
||||
name: "invalid circuit breaker failure ratio",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Circuit: CircuitConfig{
|
||||
Enabled: true,
|
||||
FailureRatio: 1.5, // Too high
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Circuit.FailureRatio",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
|
||||
} else if validationErrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, e := range validationErrs {
|
||||
if e.Field == tt.errorField {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected validation error for field %s, but got errors for: %v",
|
||||
tt.errorField, validationErrs)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidationErrorMessage tests validation error formatting
|
||||
func TestValidationErrorMessage(t *testing.T) {
|
||||
errs := ValidationErrors{
|
||||
{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "is required",
|
||||
Value: nil,
|
||||
},
|
||||
{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "must be at least 32 characters",
|
||||
Value: 16,
|
||||
},
|
||||
}
|
||||
|
||||
errMsg := errs.Error()
|
||||
|
||||
if !strings.Contains(errMsg, "Provider.IssuerURL") {
|
||||
t.Error("Error message should contain field name Provider.IssuerURL")
|
||||
}
|
||||
if !strings.Contains(errMsg, "is required") {
|
||||
t.Error("Error message should contain 'is required'")
|
||||
}
|
||||
if !strings.Contains(errMsg, "Session.EncryptionKey") {
|
||||
t.Error("Error message should contain field name Session.EncryptionKey")
|
||||
}
|
||||
if !strings.Contains(errMsg, "must be at least 32 characters") {
|
||||
t.Error("Error message should contain 'must be at least 32 characters'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedisConfig tests Redis configuration validation
|
||||
func TestValidateRedisConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *RedisConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid standalone config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing address for standalone",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Redis address is required",
|
||||
},
|
||||
{
|
||||
name: "valid cluster config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing cluster addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "cluster address is required",
|
||||
},
|
||||
{
|
||||
name: "valid sentinel config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing master name for sentinel",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Master name is required",
|
||||
},
|
||||
{
|
||||
name: "missing sentinel addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "sentinel address is required",
|
||||
},
|
||||
{
|
||||
name: "disabled redis needs no validation",
|
||||
config: &RedisConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid redis mode",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: "invalid-mode",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Invalid Redis mode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateRateLimit Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateRateLimit_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = false
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when rate limiting is disabled")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidConfig(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid rate limit config")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooLow(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0
|
||||
config.RateLimit.Burst = 100
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooHigh(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 15000
|
||||
config.RateLimit.Burst = 20000
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_BurstTooSmall(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 50 // Less than RequestsPerSecond
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.Burst", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "at least as large as requests per second")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_InvalidKeyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType string
|
||||
}{
|
||||
{"empty key type", ""},
|
||||
{"invalid key type", "invalid"},
|
||||
{"random string", "foobar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = tt.keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.KeyType", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid key type")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidKeyTypes(t *testing.T) {
|
||||
validKeyTypes := []string{"ip", "user", "token", "custom"}
|
||||
|
||||
for _, keyType := range validKeyTypes {
|
||||
t.Run(keyType, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid key type: %s", keyType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0 // Too low
|
||||
config.RateLimit.Burst = 50 // Will pass (0 < 50)
|
||||
config.RateLimit.KeyType = "invalid" // Invalid
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
// Should have 2 errors (rps and keyType)
|
||||
assert.Len(t, errors, 2)
|
||||
|
||||
// Check each error is present
|
||||
fields := make(map[string]bool)
|
||||
for _, err := range errors {
|
||||
fields[err.Field] = true
|
||||
}
|
||||
assert.True(t, fields["RateLimit.RequestsPerSecond"])
|
||||
assert.True(t, fields["RateLimit.KeyType"])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateMetrics Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateMetrics_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = false
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when metrics are disabled")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidPrometheus(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "prometheus"
|
||||
config.Metrics.Endpoint = "" // Prometheus doesn't require endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid prometheus config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidStatsd(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "localhost:8125"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid statsd config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidOTLP(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "localhost:4317"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid otlp config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_InvalidProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
}{
|
||||
{"empty provider", ""},
|
||||
{"invalid provider", "invalid"},
|
||||
{"datadog", "datadog"},
|
||||
{"influx", "influx"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = tt.provider
|
||||
config.Metrics.Endpoint = "localhost:8080"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid metrics provider")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMetrics_StatsdMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for statsd provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_OTLPMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for otlp provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "invalid" // Invalid provider
|
||||
config.Metrics.Endpoint = "" // Would be missing if provider was statsd/otlp
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
// Should have at least 1 error for invalid provider
|
||||
assert.NotEmpty(t, errors)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,456 @@
|
||||
# Configuration Reference
|
||||
|
||||
Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
- [Access Control](#access-control)
|
||||
- [Headers Configuration](#headers-configuration)
|
||||
- [Security Headers](#security-headers)
|
||||
- [Scope Configuration](#scope-configuration)
|
||||
- [Advanced Options](#advanced-options)
|
||||
|
||||
---
|
||||
|
||||
## Required Parameters
|
||||
|
||||
| Parameter | Type | Description | Example |
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
### Basic Configuration Example
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
|
||||
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
|
||||
|
||||
```yaml
|
||||
forceHTTPS: true # Required for correct redirect URIs
|
||||
```
|
||||
|
||||
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
|
||||
### Audience Validation
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `audience` | string | `clientID` | Expected audience for access token validation |
|
||||
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
|
||||
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
|
||||
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
|
||||
|
||||
#### Production Security Configuration
|
||||
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
#### Opaque Token Support
|
||||
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Other Security Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
|
||||
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
### Multi-Subdomain Setup
|
||||
|
||||
```yaml
|
||||
cookieDomain: .example.com # Share cookies across subdomains
|
||||
```
|
||||
|
||||
### Multiple Middleware Instances
|
||||
|
||||
When running multiple middleware instances with different authorization requirements, use unique prefixes:
|
||||
|
||||
```yaml
|
||||
# User authentication middleware
|
||||
---
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-userauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_userauth_"
|
||||
sessionEncryptionKey: user-encryption-key-min-32-bytes
|
||||
# ... other config
|
||||
---
|
||||
# Admin authentication middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-adminauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_adminauth_"
|
||||
sessionEncryptionKey: admin-encryption-key-min-32-bytes
|
||||
allowedUsers:
|
||||
- admin@example.com
|
||||
# ... other config
|
||||
```
|
||||
|
||||
### Extended Session Duration
|
||||
|
||||
```yaml
|
||||
sessionMaxAge: 604800 # 7 days
|
||||
# Common values:
|
||||
# 3600 - 1 hour (high security)
|
||||
# 86400 - 1 day (default)
|
||||
# 259200 - 3 days
|
||||
# 604800 - 7 days
|
||||
# 2592000 - 30 days
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Access Control
|
||||
|
||||
### User Restrictions
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `allowedUserDomains` | []string | Restrict to specific email domains |
|
||||
| `allowedUsers` | []string | Specific email addresses allowed |
|
||||
| `allowedRolesAndGroups` | []string | Required roles or groups |
|
||||
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
|
||||
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
|
||||
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
|
||||
|
||||
### Domain Restriction
|
||||
|
||||
```yaml
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### Specific User Access
|
||||
|
||||
```yaml
|
||||
allowedUsers:
|
||||
- user@example.com
|
||||
- contractor@external.org
|
||||
```
|
||||
|
||||
### Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
|
||||
```
|
||||
|
||||
### Access Control Logic
|
||||
|
||||
- If only `allowedUsers` is set: Only specified emails can access
|
||||
- If only `allowedUserDomains` is set: Only specified domains can access
|
||||
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
|
||||
- If neither is set: Any authenticated user can access
|
||||
|
||||
### Users Without Email (Azure AD)
|
||||
|
||||
For Azure AD service accounts or users without email:
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Headers Configuration
|
||||
|
||||
### Default Headers
|
||||
|
||||
The middleware sets these headers for downstream services:
|
||||
|
||||
| Header | Description |
|
||||
|--------|-------------|
|
||||
| `X-Forwarded-User` | User's email address |
|
||||
| `X-User-Groups` | Comma-separated user groups |
|
||||
| `X-User-Roles` | Comma-separated user roles |
|
||||
| `X-Auth-Request-Redirect` | Original request URI |
|
||||
| `X-Auth-Request-User` | User's email address |
|
||||
| `X-Auth-Request-Token` | User's ID token |
|
||||
|
||||
### Minimal Headers Mode
|
||||
|
||||
For "431 Request Header Fields Too Large" errors:
|
||||
|
||||
```yaml
|
||||
minimalHeaders: true # Only forwards X-Forwarded-User
|
||||
```
|
||||
|
||||
### Custom Templated Headers
|
||||
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{{{.Claims.email}}}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{{{.Claims.sub}}}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
```
|
||||
|
||||
**Template Variables:**
|
||||
- `{{.Claims.field}}` - ID token claims
|
||||
- `{{.AccessToken}}` - Raw access token
|
||||
- `{{.IdToken}}` - Raw ID token
|
||||
- `{{.RefreshToken}}` - Raw refresh token
|
||||
|
||||
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers
|
||||
|
||||
### Security Profiles
|
||||
|
||||
| Profile | Use Case | Security Level |
|
||||
|---------|----------|----------------|
|
||||
| `default` | Standard web apps | High |
|
||||
| `strict` | Maximum security | Very High |
|
||||
| `development` | Local development | Medium |
|
||||
| `api` | API endpoints | High |
|
||||
| `custom` | Custom requirements | Configurable |
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
### API with CORS
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.example.com"
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
|
||||
|
||||
# HSTS
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and Content Protection
|
||||
frameOptions: "DENY"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# CORS
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400
|
||||
|
||||
# Custom Headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "value"
|
||||
|
||||
# Server Identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### CORS Origin Patterns
|
||||
|
||||
```yaml
|
||||
corsAllowedOrigins:
|
||||
- "https://example.com" # Exact match
|
||||
- "https://*.example.com" # Subdomain wildcard
|
||||
- "http://localhost:*" # Port wildcard (development)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
### Default Behavior (Append Mode)
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- roles
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
|
||||
```
|
||||
|
||||
### Override Mode
|
||||
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "custom_scope"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Advanced Options
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true
|
||||
```
|
||||
|
||||
With Redis (recommended):
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
See [REDIS.md](REDIS.md) for complete Redis configuration.
|
||||
|
||||
---
|
||||
|
||||
## Kubernetes Secrets
|
||||
|
||||
Reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
providerURL: urn:k8s:secret:oidc-secret:ISSUER
|
||||
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:oidc-secret:SECRET
|
||||
```
|
||||
|
||||
Create the secret:
|
||||
|
||||
```bash
|
||||
kubectl create secret generic oidc-secret \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=your-client-id \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Environment Variable Naming
|
||||
|
||||
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
|
||||
|
||||
```yaml
|
||||
# Bad - may cause issues
|
||||
sessionEncryptionKey: ${OIDC_SECRET_API}
|
||||
|
||||
# Good
|
||||
sessionEncryptionKey: ${OIDC_SECRET_SVC}
|
||||
```
|
||||
@@ -0,0 +1,455 @@
|
||||
# Development Guide
|
||||
|
||||
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Local Development Setup](#local-development-setup)
|
||||
- [Running Tests](#running-tests)
|
||||
- [Test Categories](#test-categories)
|
||||
- [CI/CD Pipeline](#cicd-pipeline)
|
||||
- [Code Quality](#code-quality)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Go 1.23+** for plugin compilation
|
||||
- **Docker & Docker Compose** for local testing
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.)
|
||||
|
||||
### Required Development Tools
|
||||
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Development Setup
|
||||
|
||||
### Docker Compose Environment
|
||||
|
||||
The repository includes a Docker Compose setup for testing the plugin locally.
|
||||
|
||||
#### 1. Host Configuration
|
||||
|
||||
Add to `/etc/hosts`:
|
||||
|
||||
```bash
|
||||
127.0.0.1 hello.localhost
|
||||
127.0.0.1 traefik.localhost
|
||||
```
|
||||
|
||||
#### 2. Plugin Configuration
|
||||
|
||||
The plugin is loaded using Traefik's **local plugins mode**:
|
||||
|
||||
- Plugin source: Parent directory (`../`)
|
||||
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
|
||||
- Configuration: `experimental.localPlugins` in `traefik.yml`
|
||||
|
||||
#### 3. OIDC Provider Setup
|
||||
|
||||
Edit `docker/dynamic.yml` with your provider details:
|
||||
|
||||
**Google:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
logoutURL: "/oauth2/logout"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
**Azure AD:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
#### 4. Start Environment
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 5. Test Plugin
|
||||
|
||||
- **Protected App**: http://hello.localhost (redirects to OIDC)
|
||||
- **Traefik Dashboard**: http://traefik.localhost:8080
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Edit plugin code** in the project root
|
||||
2. **Build and test** (optional syntax check):
|
||||
```bash
|
||||
go mod tidy
|
||||
go build .
|
||||
go test ./...
|
||||
```
|
||||
3. **Restart Traefik** to reload plugin:
|
||||
```bash
|
||||
docker-compose restart traefik
|
||||
```
|
||||
4. **Test changes** at http://hello.localhost
|
||||
|
||||
### Debugging
|
||||
|
||||
**View plugin logs:**
|
||||
```bash
|
||||
docker-compose logs -f traefik | grep traefikoidc
|
||||
```
|
||||
|
||||
**Check plugin loading:**
|
||||
```bash
|
||||
docker-compose logs traefik | grep -i plugin
|
||||
```
|
||||
|
||||
**Verify plugin directory:**
|
||||
```bash
|
||||
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Fast development testing (< 30 seconds)
|
||||
go test ./... -short
|
||||
|
||||
# Standard tests with race detector
|
||||
go test -race -timeout=15m ./...
|
||||
|
||||
# With coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
```
|
||||
|
||||
### Test Modes
|
||||
|
||||
| Mode | Command | Duration | Use Case |
|
||||
|------|---------|----------|----------|
|
||||
| Quick | `go test ./... -short` | < 30s | During development |
|
||||
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
|
||||
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
|
||||
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1
|
||||
export RUN_LONG_TESTS=1
|
||||
export RUN_STRESS_TESTS=1
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
|
||||
# Customize test parameters
|
||||
export TEST_MAX_CONCURRENCY=10
|
||||
export TEST_MAX_ITERATIONS=50
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Quick Tests (Default)
|
||||
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### Extended Tests
|
||||
|
||||
- Comprehensive testing before commits
|
||||
- More iterations (5-10)
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### Long Tests
|
||||
|
||||
- Performance validation
|
||||
- High iteration counts (50-100)
|
||||
- Large data sets
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### Stress Tests
|
||||
|
||||
- Maximum load testing
|
||||
- Edge case validation
|
||||
- Extreme parameters
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Timeout: 120 seconds
|
||||
|
||||
### Running Specific Test Suites
|
||||
|
||||
```bash
|
||||
# Memory leak tests
|
||||
go test -v -run='.*Leak.*' ./...
|
||||
|
||||
# Integration tests
|
||||
go test -v -run='.*Integration.*' ./...
|
||||
|
||||
# Regression tests
|
||||
go test -v -run='.*Regression.*' ./...
|
||||
|
||||
# Provider-specific tests
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
go test -v -run='.*Google.*' ./...
|
||||
```
|
||||
|
||||
### Benchmarks
|
||||
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
|
||||
|
||||
### Triggered On
|
||||
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
### Parallel Jobs
|
||||
|
||||
#### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
#### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
#### Testing (9 suites)
|
||||
- Race Detector
|
||||
- Coverage (75% threshold)
|
||||
- Memory Leaks
|
||||
- Integration Tests
|
||||
- Regression Tests
|
||||
- Security Edge Cases
|
||||
- Session Tests
|
||||
- Token Tests
|
||||
- CSRF Tests
|
||||
|
||||
#### Provider Testing (9 providers)
|
||||
Tests run in parallel for:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Build (3 checks)
|
||||
- Benchmarks
|
||||
- Multi-platform Build (linux/darwin x amd64/arm64)
|
||||
- Go Version Compatibility (Go 1.23 & 1.24)
|
||||
|
||||
### Quality Gates
|
||||
|
||||
All PRs must pass:
|
||||
- All parallel checks
|
||||
- 75% test coverage minimum
|
||||
- Zero security vulnerabilities
|
||||
- No race conditions
|
||||
- No memory leaks
|
||||
- All providers tested
|
||||
- Builds on all platforms
|
||||
|
||||
---
|
||||
|
||||
## Code Quality
|
||||
|
||||
### Pre-Commit Checklist
|
||||
|
||||
```bash
|
||||
# Run before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "Ready to commit!"
|
||||
```
|
||||
|
||||
### Local Validation
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Static analysis
|
||||
staticcheck ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
|
||||
# Vulnerability check
|
||||
govulncheck ./...
|
||||
|
||||
# Tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# View coverage in browser
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Coverage Below Threshold:**
|
||||
```bash
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out # See uncovered lines
|
||||
```
|
||||
|
||||
**Race Condition Found:**
|
||||
```bash
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
**Linter Errors:**
|
||||
```bash
|
||||
golangci-lint run -v
|
||||
golangci-lint run --fix # Auto-fix some issues
|
||||
```
|
||||
|
||||
**Provider Test Fails:**
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
|
||||
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
|
||||
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
|
||||
4. **Documentation**: Update README and configuration files for new options
|
||||
|
||||
### Pull Request Template
|
||||
|
||||
PRs should include:
|
||||
- Description of changes
|
||||
- Type of change (bug fix, feature, breaking change, etc.)
|
||||
- Related issues
|
||||
- Provider impact (which providers are affected)
|
||||
- Testing performed
|
||||
- Security considerations
|
||||
- Performance impact
|
||||
- Breaking changes (if any)
|
||||
|
||||
### Checklist
|
||||
|
||||
Before submitting:
|
||||
- [ ] Code follows project style
|
||||
- [ ] Self-review completed
|
||||
- [ ] Tests added for new functionality
|
||||
- [ ] All tests pass locally
|
||||
- [ ] Documentation updated
|
||||
- [ ] No new warnings generated
|
||||
|
||||
### Code Owners
|
||||
|
||||
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
|
||||
|
||||
### Dependabot
|
||||
|
||||
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
@@ -0,0 +1,580 @@
|
||||
# OIDC Provider Configuration Guide
|
||||
|
||||
Configuration reference for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Provider Support Matrix](#provider-support-matrix)
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [Okta](#okta)
|
||||
- [Keycloak](#keycloak)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [GitLab](#gitlab)
|
||||
- [GitHub](#github)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
- [Automatic Scope Filtering](#automatic-scope-filtering)
|
||||
|
||||
---
|
||||
|
||||
## Provider Support Matrix
|
||||
|
||||
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|
||||
|----------|-------------|----------------|----------------|-----------|
|
||||
| Google | Full | Yes | `accounts.google.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
|
||||
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com` | Yes |
|
||||
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
|
||||
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
|
||||
| GitLab | Full | Yes | `gitlab.com` | Yes |
|
||||
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
|
||||
| Generic | Full | Yes | Any OIDC endpoint | Yes |
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- "your-gsuite-domain.com" # Optional: Workspace restriction
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
|
||||
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
|
||||
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
|
||||
|
||||
### Google Cloud Console Setup
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Navigate to APIs & Services > Credentials
|
||||
4. Create OAuth 2.0 Client ID (Web application)
|
||||
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
6. Configure OAuth consent screen (must be "Published" for production)
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# Single tenant
|
||||
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# Multi-tenant
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- "App.Users"
|
||||
- "Admin.Group"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### With Application ID URI (API Access)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
audience: "api://your-azure-client-id" # Application ID URI
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Users Without Email
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "user-object-id-1"
|
||||
- "user-object-id-2"
|
||||
```
|
||||
|
||||
### Azure AD Setup
|
||||
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to Azure Active Directory > App registrations
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Create client secret in Certificates & secrets
|
||||
6. Configure Token Configuration for group claims
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
postLogoutRedirectUri: "https://your-app.com"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### With Custom API Audience
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
audience: "https://api.your-domain.com" # API identifier
|
||||
strictAudienceValidation: true
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
roleClaimName: "https://your-app.com/roles" # Namespaced claim
|
||||
groupClaimName: "https://your-app.com/groups"
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
```
|
||||
|
||||
### Auth0 Action for Custom Claims
|
||||
|
||||
```javascript
|
||||
exports.onExecutePostLogin = async (event, api) => {
|
||||
const namespace = 'https://your-app.com/';
|
||||
if (event.authorization) {
|
||||
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
|
||||
api.idToken.setCustomClaim('email', event.user.email);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Auth0 Setup
|
||||
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create Regular Web Application
|
||||
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
6. Create API in APIs section for custom audiences
|
||||
|
||||
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
# Or with custom authorization server:
|
||||
providerURL: "https://your-domain.okta.com/oauth2/default"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-okta
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
clientID: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- "Everyone"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Setup
|
||||
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect > Web Application
|
||||
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
|
||||
6. Enable Authorization Code and Refresh Token grant types
|
||||
7. Configure Groups claim in authorization server
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-keycloak
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://keycloak.company.com/realms/your-realm"
|
||||
clientID: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- roles
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
For private IP addresses (Docker networks, Kubernetes):
|
||||
|
||||
```yaml
|
||||
providerURL: "https://192.168.1.100:8443/realms/your-realm"
|
||||
allowPrivateIPAddresses: true # Required for private IPs
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select your realm
|
||||
3. Go to Clients > Create client
|
||||
4. Set Client Protocol: openid-connect
|
||||
5. Set Access Type: confidential
|
||||
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
7. Generate client secret in Credentials tab
|
||||
8. Configure mappers to add claims to ID Token:
|
||||
- Email: User Property mapper with "Add to ID token" enabled
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-cognito
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientID: "your-cognito-client-id"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- aws.cognito.signin.user.admin
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- users
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
- Sign out URLs: `https://your-domain.com/oauth2/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
5. Set up groups for role-based access
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerURL: "https://gitlab.com"
|
||||
|
||||
# Self-hosted
|
||||
providerURL: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-gitlab
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://gitlab.com"
|
||||
clientID: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
# Note: GitLab doesn't require offline_access scope
|
||||
# Refresh tokens are issued automatically with openid
|
||||
allowedRolesAndGroups:
|
||||
- developers
|
||||
- maintainers
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Setup
|
||||
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Save and note Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://github.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oauth-github
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://github.com/login/oauth"
|
||||
clientID: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- user:email
|
||||
- read:user
|
||||
allowedUsers:
|
||||
- "github-username"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- **OAuth 2.0 only** - Not OpenID Connect
|
||||
- **No ID tokens** - Only access tokens for API calls
|
||||
- **No refresh tokens** - Users must re-authenticate on expiry
|
||||
- **No standard claims** - User info requires API calls
|
||||
|
||||
Use GitHub only for API access, not for user authentication with claims.
|
||||
|
||||
### GitHub Setup
|
||||
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
|
||||
4. Note Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
For any OIDC-compliant provider not listed above.
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-generic
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://oidc.your-provider.com"
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Requirements
|
||||
|
||||
- Provider must expose `.well-known/openid-configuration` endpoint
|
||||
- Must support authorization code flow
|
||||
- ID tokens must contain required claims (email, sub, etc.)
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Fetches provider's `.well-known/openid-configuration`
|
||||
2. Extracts `scopes_supported` field
|
||||
3. Filters requested scopes to only include supported ones
|
||||
4. Falls back to all requested scopes if provider doesn't declare supported scopes
|
||||
|
||||
### Example: Self-Hosted GitLab
|
||||
|
||||
Self-hosted GitLab may reject `offline_access` scope:
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access # Will be automatically filtered out if unsupported
|
||||
```
|
||||
|
||||
The middleware will:
|
||||
1. Read GitLab's discovery document
|
||||
2. Detect `offline_access` is NOT in `scopes_supported`
|
||||
3. Filter it out automatically
|
||||
4. Authentication succeeds
|
||||
|
||||
### Logging
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If a provider rejects scopes even after filtering:
|
||||
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
|
||||
2. Use `overrideScopes: true` with only supported scopes
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
@@ -1,955 +0,0 @@
|
||||
# Provider-Specific Configuration Guide
|
||||
|
||||
This guide covers the configuration requirements and best practices for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [GitHub](#github)
|
||||
- [GitLab](#gitlab)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [Keycloak](#keycloak)
|
||||
- [Okta](#okta)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-google-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
|
||||
- **Refresh token support**: Fully supported
|
||||
- **Domain restrictions**: Can restrict by Google Workspace domains
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedUserDomains: ["example.com", "company.org"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google OAuth Console Setup
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# For Azure AD (single tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# For Azure AD (multi-tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-azure-application-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Azure-Specific Features
|
||||
- **Response mode**: Automatically adds `response_mode=query`
|
||||
- **Offline access**: Requires `offline_access` scope for refresh tokens
|
||||
- **Access token validation**: Supports both JWT and opaque access tokens
|
||||
- **Tenant isolation**: Can restrict to specific Azure AD tenants
|
||||
- **Application ID URI**: Supports custom audience for protected APIs
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Azure AD API Configuration (Application ID URI)
|
||||
|
||||
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
# Specify the Application ID URI as audience
|
||||
audience: "api://12345678-1234-1234-1234-123456789abc"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
|
||||
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected
|
||||
- For ID token validation only (no API access), you can omit the `audience` parameter
|
||||
|
||||
### Azure App Registration Setup
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to "Azure Active Directory" > "App registrations"
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Create client secret in "Certificates & secrets"
|
||||
6. Configure API permissions for required scopes
|
||||
|
||||
### Azure AD API Exposure Setup (for custom audiences)
|
||||
1. In your App Registration, go to "Expose an API"
|
||||
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
|
||||
3. Add any custom scopes your API exposes
|
||||
4. Update the middleware configuration to include the `audience` parameter with this URI
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Auth0-Specific Features
|
||||
- **Custom domains**: Supports Auth0 custom domains
|
||||
- **Rules and hooks**: Leverages Auth0's extensibility
|
||||
- **Social connections**: Works with Auth0's social identity providers
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
- **API audiences**: Supports custom audience for API access tokens
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedUsers: ["user@example.com", "admin@company.com"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Auth0 API Configuration (Custom Audience)
|
||||
|
||||
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
# Specify the Auth0 API identifier as audience
|
||||
audience: "https://api.company.com"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
|
||||
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
|
||||
- For ID token validation only (no APIs), you can omit the `audience` parameter
|
||||
|
||||
### Auth0 Application Setup
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create new application (Regular Web Application)
|
||||
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
|
||||
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
|
||||
### Auth0 API Setup (for custom audiences)
|
||||
1. Go to Auth0 Dashboard → APIs
|
||||
2. Create a new API or select existing API
|
||||
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
|
||||
4. In API Settings → Machine to Machine Applications, authorize your application
|
||||
5. Configure API permissions/scopes as needed
|
||||
6. Use the API identifier as the `audience` parameter in your configuration
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://github.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["read:user", "user:email"]
|
||||
```
|
||||
|
||||
### GitHub-Specific Features
|
||||
- **Organization membership**: Can restrict by GitHub organization
|
||||
- **Team membership**: Can restrict by specific teams
|
||||
- **Limited OIDC**: GitHub has limited OIDC support
|
||||
- **Email verification**: Requires verified email addresses
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
github-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://github.com"
|
||||
clientId: "Iv1.abcdef123456"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["read:user", "user:email"]
|
||||
allowedUsers: ["octocat", "github-user"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### GitHub OAuth App Setup
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
|
||||
4. Note the Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerUrl: "https://gitlab.com"
|
||||
|
||||
# Self-hosted GitLab
|
||||
providerUrl: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### GitLab-Specific Features
|
||||
- **Self-hosted support**: Works with self-hosted GitLab instances
|
||||
- **Group membership**: Can restrict by GitLab groups
|
||||
- **Project access**: Can validate project permissions
|
||||
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
# Note: GitLab doesn't support the offline_access scope.
|
||||
# Refresh tokens are issued automatically for the openid scope.
|
||||
allowedRolesAndGroups: ["developers", "maintainers"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Application Setup
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Save and note the Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-cognito-app-client-id"
|
||||
clientSecret: "your-cognito-app-client-secret" # If app client has secret
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Cognito-Specific Features
|
||||
- **User pools**: Integrates with Cognito User Pools
|
||||
- **Custom attributes**: Supports custom user attributes
|
||||
- **Groups**: Can validate Cognito user group membership
|
||||
- **Regional endpoints**: Requires region-specific URLs
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
cognito-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientId: "1234567890abcdefghij"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedRolesAndGroups: ["admin", "users"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/auth/callback`
|
||||
- Sign out URLs: `https://your-domain.com/auth/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Keycloak-Specific Features
|
||||
- **Realm support**: Multi-realm deployments
|
||||
- **Custom mappers**: Rich claim mapping capabilities
|
||||
- **Role-based access**: Fine-grained role management
|
||||
- **Offline access**: Full refresh token support
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
keycloak-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://keycloak.company.com/realms/employees"
|
||||
clientId: "traefik-app"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["app-users", "administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select appropriate realm
|
||||
3. Create new client:
|
||||
- Client Protocol: openid-connect
|
||||
- Access Type: confidential
|
||||
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.okta.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Okta-Specific Features
|
||||
- **Custom authorization servers**: Supports custom auth servers
|
||||
- **Group claims**: Rich group membership information
|
||||
- **Universal Directory**: Integrates with Okta's user store
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
okta-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.okta.com"
|
||||
clientId: "0oa123456789abcdef"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["Everyone", "Administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Application Setup
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect
|
||||
4. Choose Web Application
|
||||
5. Configure:
|
||||
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
|
||||
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
|
||||
- Grant types: Authorization Code, Refresh Token
|
||||
6. Assign users or groups
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-oidc-provider.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Generic Features
|
||||
- **Standards compliance**: Works with any OIDC-compliant provider
|
||||
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
|
||||
- **Flexible scopes**: Supports custom scope requirements
|
||||
- **Custom claims**: Works with provider-specific claims
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
generic-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://oidc.your-provider.com"
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
### Overview
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
|
||||
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
|
||||
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
|
||||
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
|
||||
|
||||
### Example Scenarios
|
||||
|
||||
#### Self-Hosted GitLab
|
||||
|
||||
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
|
||||
```
|
||||
The requested scope is invalid, unknown, or malformed.
|
||||
```
|
||||
|
||||
**Solution**: The middleware automatically detects this by:
|
||||
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
|
||||
2. Observing that `offline_access` is NOT in the `scopes_supported` list
|
||||
3. Filtering out `offline_access` from the request
|
||||
4. Authentication succeeds
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.example.com"
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
# Even though offline_access is listed, it will be automatically
|
||||
# filtered out if GitLab doesn't support it
|
||||
```
|
||||
|
||||
#### Auth0 or Keycloak
|
||||
|
||||
These providers typically support `offline_access` and it will be included:
|
||||
|
||||
```yaml
|
||||
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
|
||||
# Result: All requested scopes are sent
|
||||
```
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
|
||||
2. **No Manual Configuration**: No need to know which scopes each provider supports
|
||||
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
|
||||
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
|
||||
5. **Backward Compatible**: Existing configurations continue to work
|
||||
|
||||
### Logging
|
||||
|
||||
The middleware provides detailed logging for scope filtering:
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
|
||||
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Issue**: Provider rejects scope even after filtering
|
||||
|
||||
**Possible Causes**:
|
||||
1. Provider's discovery document is outdated
|
||||
2. Provider doesn't properly implement `scopes_supported`
|
||||
3. Custom authorization server with non-standard behavior
|
||||
|
||||
**Solutions**:
|
||||
1. Use `overrideScopes: true` and explicitly list only supported scopes
|
||||
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
|
||||
---
|
||||
|
||||
## Common Configuration Options
|
||||
|
||||
### Audience Configuration
|
||||
|
||||
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
|
||||
|
||||
```yaml
|
||||
# Optional: Custom audience for JWT validation
|
||||
# If not set, defaults to clientID for backward compatibility
|
||||
audience: "https://api.example.com" # Auth0 API identifier
|
||||
# OR
|
||||
audience: "api://12345-guid" # Azure AD Application ID URI
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- **Auth0**: When using Auth0 APIs with custom audience parameters
|
||||
- **Azure AD**: When exposing your app as an API with Application ID URI
|
||||
- **Keycloak**: When using audience-restricted tokens
|
||||
- **Okta**: When using custom authorization servers with API audiences
|
||||
|
||||
**When to omit**:
|
||||
- For standard ID token validation (default behavior)
|
||||
- When the provider sets `aud` claim to your `clientID`
|
||||
- For backward compatibility with existing configurations
|
||||
|
||||
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
|
||||
|
||||
### Security Settings
|
||||
```yaml
|
||||
# Force HTTPS (recommended for production)
|
||||
forceHttps: true
|
||||
|
||||
# Enable PKCE (recommended for security)
|
||||
enablePkce: true
|
||||
|
||||
# Session encryption key (32+ characters)
|
||||
sessionEncryptionKey: "your-very-long-encryption-key-here"
|
||||
```
|
||||
|
||||
### Access Control
|
||||
```yaml
|
||||
# Restrict by email addresses
|
||||
allowedUsers: ["user1@example.com", "user2@example.com"]
|
||||
|
||||
# Restrict by email domains
|
||||
allowedUserDomains: ["company.com", "partner.org"]
|
||||
|
||||
# Restrict by roles/groups (provider-specific)
|
||||
allowedRolesAndGroups: ["admin", "users", "developers"]
|
||||
```
|
||||
|
||||
### URLs and Endpoints
|
||||
```yaml
|
||||
# OAuth callback URL (must match provider config)
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
|
||||
# Logout endpoint
|
||||
logoutUrl: "https://your-domain.com/auth/logout"
|
||||
|
||||
# Post-logout redirect (optional)
|
||||
postLogoutRedirectUri: "https://your-domain.com"
|
||||
|
||||
# URLs to exclude from authentication
|
||||
excludedUrls: ["/health", "/metrics", "/public"]
|
||||
```
|
||||
|
||||
### Advanced Settings
|
||||
```yaml
|
||||
# Override default scopes
|
||||
overrideScopes: true
|
||||
scopes: ["openid", "custom_scope"]
|
||||
|
||||
# Rate limiting (requests per second)
|
||||
rateLimit: 10
|
||||
|
||||
# Token refresh grace period (seconds)
|
||||
refreshGracePeriodSeconds: 60
|
||||
|
||||
# Cookie domain (for subdomain sharing)
|
||||
cookieDomain: ".example.com"
|
||||
|
||||
# Custom headers to inject
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-Name"
|
||||
value: "{{.Claims.name}}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Invalid redirect URI**
|
||||
- Ensure callback URL exactly matches provider configuration
|
||||
- Check for HTTP vs HTTPS mismatches
|
||||
|
||||
2. **Scope errors**
|
||||
- Verify required scopes are configured in provider
|
||||
- Some providers require specific scopes for refresh tokens
|
||||
|
||||
3. **Token validation failures**
|
||||
- Check provider URL format and accessibility
|
||||
- Verify `.well-known/openid-configuration` endpoint is reachable
|
||||
|
||||
4. **Session issues**
|
||||
- Ensure session encryption key is properly configured
|
||||
- Check cookie domain settings for subdomain scenarios
|
||||
|
||||
### Debug Mode
|
||||
Enable debug logging to troubleshoot configuration issues:
|
||||
```yaml
|
||||
logLevel: "debug"
|
||||
```
|
||||
|
||||
This will provide detailed logs of the authentication flow and help identify configuration problems.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
|
||||
|
||||
### Default Security Headers
|
||||
|
||||
By default, the plugin applies these security headers:
|
||||
|
||||
- `X-Frame-Options: DENY` - Prevents clickjacking
|
||||
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
|
||||
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
|
||||
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
|
||||
|
||||
### Security Profiles
|
||||
|
||||
Choose from predefined security profiles or create custom configurations:
|
||||
|
||||
#### Default Profile (Recommended)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
#### Strict Profile (Maximum Security)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
# Additional strict CSP and cross-origin policies
|
||||
```
|
||||
|
||||
#### Development Profile (Local Development)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "development"
|
||||
# Relaxed policies for local development
|
||||
```
|
||||
|
||||
#### API Profile (API Endpoints)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://your-frontend.com"]
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
For complete control, use the custom profile:
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
|
||||
|
||||
# HSTS Configuration
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000 # 1 year
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and content protection
|
||||
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions policy (feature policy)
|
||||
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Cross-origin policies
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# CORS configuration
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://app.example.com"
|
||||
- "https://*.api.example.com"
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400 # 24 hours
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "custom-value"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Server identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### Complete Example with Security Headers
|
||||
|
||||
Here's a complete configuration example for Google OIDC with custom security headers:
|
||||
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
secure-google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
# OIDC Configuration
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key-here"
|
||||
|
||||
# Domain restrictions
|
||||
allowedUserDomains: ["your-company.com"]
|
||||
|
||||
# Security Headers
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.your-domain.com"
|
||||
corsAllowCredentials: true
|
||||
customHeaders:
|
||||
X-Company: "YourCompany"
|
||||
X-Environment: "production"
|
||||
|
||||
routers:
|
||||
secure-app:
|
||||
rule: "Host(`your-domain.com`)"
|
||||
middlewares:
|
||||
- secure-google-oidc
|
||||
service: your-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
```
|
||||
|
||||
### CORS Configuration Details
|
||||
|
||||
For applications with frontend-backend separation, configure CORS properly:
|
||||
|
||||
#### Simple CORS (Single Origin)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Wildcard Subdomains
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://*.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Development with Multiple Ports
|
||||
```yaml
|
||||
securityHeaders:
|
||||
profile: "development"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "http://localhost:*"
|
||||
- "http://127.0.0.1:*"
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
- Set `forceHttps: true`
|
||||
- Configure proper TLS certificates
|
||||
|
||||
2. **Implement proper CSP**
|
||||
- Start with strict policy
|
||||
- Add exceptions only when necessary
|
||||
- Test thoroughly
|
||||
|
||||
3. **Configure CORS restrictively**
|
||||
- Only allow necessary origins
|
||||
- Use specific domains instead of wildcards when possible
|
||||
|
||||
4. **Enable HSTS**
|
||||
- Use long max-age values (1 year minimum)
|
||||
- Include subdomains when appropriate
|
||||
|
||||
5. **Monitor security headers**
|
||||
- Use browser developer tools to verify headers
|
||||
- Test with security scanning tools
|
||||
- Regularly review and update policies
|
||||
|
||||
### Testing Security Headers
|
||||
|
||||
Use browser developer tools or online tools to verify your security headers:
|
||||
|
||||
1. **Browser DevTools**: Check Network tab → Response Headers
|
||||
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
|
||||
3. **Command line**: Use `curl -I https://your-domain.com`
|
||||
|
||||
Example verification:
|
||||
```bash
|
||||
curl -I https://your-domain.com
|
||||
# Should show security headers in response
|
||||
```
|
||||
+546
@@ -0,0 +1,546 @@
|
||||
# Redis Cache for Distributed Deployments
|
||||
|
||||
Redis cache support for multi-replica Traefik deployments with shared state.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Why Use Redis Cache?](#why-use-redis-cache)
|
||||
- [Configuration](#configuration)
|
||||
- [Cache Modes](#cache-modes)
|
||||
- [Deployment Examples](#deployment-examples)
|
||||
- [Performance Tuning](#performance-tuning)
|
||||
- [Monitoring](#monitoring)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [Migration Guide](#migration-guide)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
|
||||
- **Shared Session Management**: Consistent user sessions across replicas
|
||||
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
|
||||
- **Health Checking**: Continuous monitoring of Redis connectivity
|
||||
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
|
||||
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
||||
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3 │
|
||||
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
|
||||
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
|
||||
│ │ │
|
||||
└────────────────────┼────────────────────┘
|
||||
│
|
||||
┌──────▼──────┐
|
||||
│ Redis │
|
||||
│ (Shared │
|
||||
│ Cache) │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why Use Redis Cache?
|
||||
|
||||
### The Problem
|
||||
|
||||
When running multiple Traefik instances without shared cache:
|
||||
|
||||
1. **False Positive Replay Detection**
|
||||
- User authenticates → Token stored in Instance A's JTI cache
|
||||
- Next request → Load balancer routes to Instance B
|
||||
- Instance B doesn't have the JTI → Falsely detects replay attack
|
||||
|
||||
2. **Session Inconsistency**
|
||||
- User session created on Instance A
|
||||
- Subsequent request routed to Instance B
|
||||
- Instance B has no knowledge of the session
|
||||
|
||||
3. **Token Metadata Fragmentation**
|
||||
- Token refresh happens on Instance A
|
||||
- Other instances continue using old tokens
|
||||
|
||||
### The Solution
|
||||
|
||||
Redis provides centralized cache that all instances share, ensuring:
|
||||
|
||||
- **Consistent Authentication**: All instances share authentication state
|
||||
- **True Replay Detection**: JTI cache shared across all instances
|
||||
- **Seamless Scaling**: Add/remove instances without affecting sessions
|
||||
- **High Availability**: Circuit breaker with automatic fallback
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### All Configuration Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `enabled` | bool | `false` | Enable Redis caching |
|
||||
| `address` | string | - | Redis server address (`host:port`) |
|
||||
| `password` | string | - | Redis password (optional) |
|
||||
| `db` | int | `0` | Redis database number (0-15) |
|
||||
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
|
||||
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
|
||||
| `poolSize` | int | `10` | Connection pool size |
|
||||
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
|
||||
| `readTimeout` | int | `3` | Read timeout (seconds) |
|
||||
| `writeTimeout` | int | `3` | Write timeout (seconds) |
|
||||
| `enableTLS` | bool | `false` | Enable TLS for connections |
|
||||
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
|
||||
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
|
||||
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
|
||||
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
|
||||
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
|
||||
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
|
||||
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables are used:
|
||||
|
||||
```bash
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=your-password
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=10
|
||||
REDIS_CONNECT_TIMEOUT=5
|
||||
REDIS_READ_TIMEOUT=3
|
||||
REDIS_WRITE_TIMEOUT=3
|
||||
REDIS_ENABLE_TLS=false
|
||||
REDIS_TLS_SKIP_VERIFY=false
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Modes
|
||||
|
||||
### Memory Mode (Default without Redis)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "memory"
|
||||
```
|
||||
|
||||
- Uses only in-memory cache
|
||||
- Suitable for single-instance deployments
|
||||
- No Redis dependency
|
||||
- Fastest performance
|
||||
|
||||
### Redis Mode
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
- All operations go directly to Redis
|
||||
- Ensures consistency across replicas
|
||||
- Slightly higher latency
|
||||
|
||||
### Hybrid Mode (Recommended)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
Two-tier caching strategy:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Client Request │
|
||||
└────────────────┬────────────────────────┘
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Local Cache │ ← L1 Cache (Fast)
|
||||
│ (Memory) │
|
||||
└────────┬───────┘
|
||||
│ Miss
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Remote Cache │ ← L2 Cache (Shared)
|
||||
│ (Redis) │
|
||||
└────────────────┘
|
||||
```
|
||||
|
||||
**Read Path:**
|
||||
1. Check local memory cache (L1)
|
||||
2. On miss, check Redis (L2)
|
||||
3. On hit in Redis, populate L1
|
||||
4. Return value
|
||||
|
||||
**Write Path:**
|
||||
1. Write to Redis (L2) for durability
|
||||
2. Write to local cache (L1) for speed
|
||||
|
||||
### Performance Comparison
|
||||
|
||||
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|
||||
|-----------|------------|------------|-------------|
|
||||
| Read (p50) | 0.1ms | 2ms | 0.2ms |
|
||||
| Read (p99) | 0.5ms | 10ms | 5ms |
|
||||
| Write (p50) | 0.2ms | 3ms | 3ms |
|
||||
| Throughput | 100k/s | 20k/s | 80k/s |
|
||||
|
||||
---
|
||||
|
||||
## Deployment Examples
|
||||
|
||||
### Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD}
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 30s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
deploy:
|
||||
replicas: 3
|
||||
labels:
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
```
|
||||
|
||||
### AWS ElastiCache
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "your-cache.abc123.cache.amazonaws.com:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableTLS: true
|
||||
password: "your-elasticache-auth-token"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Connection Pool Sizing
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
poolSize: 20 # Formula: 2 * CPU cores * replicas
|
||||
# For 4 cores, 3 replicas: poolSize = 24
|
||||
```
|
||||
|
||||
### TTL Strategy
|
||||
|
||||
The plugin automatically sets TTLs based on token lifetimes:
|
||||
|
||||
- **JTI Cache**: Matches token lifetime (typically 1 hour)
|
||||
- **Session**: Matches `sessionMaxAge` configuration
|
||||
- **Token Metadata**: 5 minutes (short-lived)
|
||||
|
||||
### Redis Server Configuration
|
||||
|
||||
```bash
|
||||
# Recommended Redis settings for cache
|
||||
maxmemory 512mb
|
||||
maxmemory-policy allkeys-lru # Evict least recently used
|
||||
|
||||
# For cache data, disable persistence for better performance
|
||||
save ""
|
||||
appendonly no
|
||||
```
|
||||
|
||||
### Hybrid Mode Tuning
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "hybrid"
|
||||
hybridL1Size: 500 # Max items in local cache
|
||||
hybridL1MemoryMB: 10 # Max memory for local cache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Key Metrics
|
||||
|
||||
- **Cache hit rate** (target: >90% for hybrid mode)
|
||||
- **Redis latency** (target: <10ms p99)
|
||||
- **Circuit breaker state**
|
||||
- **Connection pool utilization
|
||||
|
||||
### Redis Commands for Monitoring
|
||||
|
||||
```bash
|
||||
# Monitor commands in real-time
|
||||
redis-cli MONITOR
|
||||
|
||||
# Check slow queries
|
||||
redis-cli SLOWLOG GET 10
|
||||
|
||||
# Memory usage
|
||||
redis-cli INFO memory
|
||||
|
||||
# Key statistics
|
||||
redis-cli DBSIZE
|
||||
|
||||
# List keys with prefix
|
||||
redis-cli --scan --pattern "traefikoidc:*"
|
||||
|
||||
# Check key TTL
|
||||
redis-cli TTL "traefikoidc:session:abc123"
|
||||
```
|
||||
|
||||
### Health Check Endpoint
|
||||
|
||||
The plugin provides health information including:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"cache": {
|
||||
"mode": "hybrid",
|
||||
"redis": {
|
||||
"connected": true,
|
||||
"latency": "2ms"
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": "closed",
|
||||
"failures": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Symptoms:** `dial tcp: connection refused`
|
||||
|
||||
**Solutions:**
|
||||
1. Verify Redis is running: `redis-cli ping`
|
||||
2. Check network connectivity: `telnet redis-host 6379`
|
||||
3. Verify address configuration
|
||||
|
||||
### Authentication Failure
|
||||
|
||||
**Symptoms:** `NOAUTH Authentication required`
|
||||
|
||||
**Solutions:**
|
||||
1. Set Redis password in configuration
|
||||
2. Verify password is correct
|
||||
|
||||
### Circuit Breaker Open
|
||||
|
||||
**Symptoms:** `Circuit breaker is open`, falling back to memory
|
||||
|
||||
**Solutions:**
|
||||
1. Check Redis health: `redis-cli INFO server`
|
||||
2. Review network latency: `redis-cli --latency`
|
||||
3. Adjust circuit breaker thresholds if needed
|
||||
|
||||
### High Memory Usage
|
||||
|
||||
**Symptoms:** Redis memory constantly growing, OOM errors
|
||||
|
||||
**Solutions:**
|
||||
1. Configure eviction policy:
|
||||
```bash
|
||||
CONFIG SET maxmemory 512mb
|
||||
CONFIG SET maxmemory-policy allkeys-lru
|
||||
```
|
||||
2. Review key count: `redis-cli DBSIZE`
|
||||
3. Check for large keys: `redis-cli --bigkeys`
|
||||
|
||||
### Inconsistent Cache State
|
||||
|
||||
**Symptoms:** Different responses from different replicas
|
||||
|
||||
**Solutions:**
|
||||
1. Verify all instances use the same Redis address
|
||||
2. Check cache mode consistency across instances
|
||||
3. Verify time synchronization on all hosts
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Memory-Only to Redis
|
||||
|
||||
#### Phase 1: Preparation
|
||||
|
||||
1. Deploy Redis infrastructure
|
||||
2. Test Redis connectivity
|
||||
3. Configure monitoring
|
||||
|
||||
#### Phase 2: Gradual Rollout
|
||||
|
||||
1. Enable Redis on one instance:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
2. Monitor for errors
|
||||
3. Gradually enable on more instances
|
||||
|
||||
#### Phase 3: Full Migration
|
||||
|
||||
1. Enable Redis on all instances
|
||||
2. Remove `disableReplayDetection: true` if set
|
||||
3. Monitor for issues
|
||||
|
||||
### Rollback Plan
|
||||
|
||||
If issues occur:
|
||||
1. Set `redis.enabled: false`
|
||||
2. Plugin falls back to memory cache automatically
|
||||
3. Investigate and resolve issues
|
||||
|
||||
### Migration Checklist
|
||||
|
||||
- [ ] Redis deployed and accessible
|
||||
- [ ] Redis password configured
|
||||
- [ ] Network connectivity verified
|
||||
- [ ] Monitoring configured
|
||||
- [ ] Backup plan prepared
|
||||
- [ ] Test environment validated
|
||||
- [ ] Gradual rollout planned
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Security
|
||||
|
||||
- Always use Redis password authentication
|
||||
- Enable TLS for production deployments
|
||||
- Use network segmentation (private subnets)
|
||||
- Rotate Redis passwords regularly
|
||||
|
||||
### High Availability
|
||||
|
||||
- Use Redis Sentinel or Cluster for HA
|
||||
- Configure appropriate circuit breaker thresholds
|
||||
- Implement proper health checks
|
||||
- Use connection pooling
|
||||
|
||||
### Performance
|
||||
|
||||
- Use hybrid cache mode for best performance
|
||||
- Monitor cache hit rates
|
||||
- Size Redis memory appropriately
|
||||
- Disable persistence for cache-only usage
|
||||
|
||||
### Operations
|
||||
|
||||
- Implement comprehensive monitoring
|
||||
- Set up alerting for circuit breaker state
|
||||
- Document Redis configuration
|
||||
- Test failover scenarios
|
||||
|
||||
---
|
||||
|
||||
## FAQ
|
||||
|
||||
### Is Redis required?
|
||||
|
||||
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
|
||||
|
||||
### What happens if Redis goes down?
|
||||
|
||||
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
|
||||
|
||||
### Which cache mode should I use?
|
||||
|
||||
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
|
||||
|
||||
### How much memory does Redis need?
|
||||
|
||||
Depends on active sessions and token sizes:
|
||||
- Small (1-1000 users): 128MB
|
||||
- Medium (1000-10000 users): 256-512MB
|
||||
- Large (10000+ users): 1GB+
|
||||
|
||||
### Can I use managed Redis services?
|
||||
|
||||
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
|
||||
|
||||
### Is data encrypted in Redis?
|
||||
|
||||
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
|
||||
-1125
File diff suppressed because it is too large
Load Diff
@@ -1,413 +0,0 @@
|
||||
# Redis Cache Backend Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Directory Organization
|
||||
|
||||
```
|
||||
internal/cache/
|
||||
├── backend/
|
||||
│ ├── interface.go # CacheBackend interface definition
|
||||
│ ├── interface_test.go # Contract tests for all backends
|
||||
│ ├── memory.go # In-memory backend implementation
|
||||
│ ├── memory_test.go # Memory backend unit tests
|
||||
│ ├── redis.go # Redis backend implementation
|
||||
│ ├── redis_test.go # Redis backend unit tests
|
||||
│ ├── errors.go # Error definitions
|
||||
│ └── test_helpers_test.go # Test infrastructure and helpers
|
||||
│
|
||||
└── resilience/
|
||||
├── circuit_breaker.go # Circuit breaker implementation
|
||||
├── circuit_breaker_test.go # Circuit breaker tests
|
||||
├── health_check.go # Health checker implementation
|
||||
└── health_check_test.go # Health check tests
|
||||
|
||||
redis_integration_test.go # End-to-end integration tests
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Interface Contract Tests (`interface_test.go`)
|
||||
|
||||
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
|
||||
|
||||
**Test Cases:**
|
||||
- `TestCacheBackendContract` - Runs all contract tests against each backend type
|
||||
- `testBasicSetGet` - Verifies basic set/get operations
|
||||
- `testGetNonExistent` - Tests behavior for non-existent keys
|
||||
- `testUpdateExisting` - Validates updating existing keys
|
||||
- `testDelete` - Tests delete operations
|
||||
- `testDeleteNonExistent` - Delete non-existent keys
|
||||
- `testExists` - Key existence checking
|
||||
- `testTTLExpiration` - TTL and expiration behavior
|
||||
- `testClear` - Clear all keys operation
|
||||
- `testPing` - Health check functionality
|
||||
- `testStats` - Statistics tracking
|
||||
- `testConcurrentAccess` - Thread safety with 10+ goroutines
|
||||
- `testLargeValues` - Handling of 1MB+ values
|
||||
- `testEmptyValues` - Empty byte array handling
|
||||
- `testSpecialCharactersInKeys` - Special characters in key names
|
||||
|
||||
**Coverage:** ~95% of interface methods
|
||||
|
||||
### 2. Memory Backend Tests (`memory_test.go`)
|
||||
|
||||
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (6 tests)
|
||||
- `TestMemoryBackend_BasicOperations` - CRUD operations
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- DeleteNonExistent
|
||||
- Exists
|
||||
- Clear
|
||||
|
||||
#### TTL and Expiration (3 tests)
|
||||
- `TestMemoryBackend_TTLExpiration`
|
||||
- ShortTTL (100ms)
|
||||
- TTLDecrement over time
|
||||
- CleanupExpiredItems
|
||||
|
||||
#### LRU Eviction (2 tests)
|
||||
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
|
||||
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
|
||||
|
||||
#### Edge Cases (6 tests)
|
||||
- `TestMemoryBackend_UpdateExisting` - Overwriting values
|
||||
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
|
||||
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
|
||||
- `TestMemoryBackend_LargeValues` - 1MB values
|
||||
- `TestMemoryBackend_Close` - Proper cleanup
|
||||
- `TestMemoryBackend_Ping` - Health checks
|
||||
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
|
||||
|
||||
**Coverage:** ~92% of memory backend code
|
||||
|
||||
### 3. Redis Backend Tests (`redis_test.go`)
|
||||
|
||||
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (4 tests)
|
||||
- `TestRedisBackend_BasicOperations`
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- Exists
|
||||
|
||||
#### Redis-Specific Features (6 tests)
|
||||
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
|
||||
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
|
||||
- `TestRedisBackend_Clear` - Bulk delete with SCAN
|
||||
- `TestRedisBackend_NoPrefix` - Operation without prefix
|
||||
|
||||
#### Error Handling (2 tests)
|
||||
- `TestRedisBackend_ConnectionFailure` - Connection errors
|
||||
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
|
||||
|
||||
#### Advanced Features (3 tests)
|
||||
- `TestRedisBackend_PipelineOperations`
|
||||
- SetMany (batch writes)
|
||||
- GetMany (batch reads)
|
||||
- GetManyWithNonExistent
|
||||
|
||||
#### Edge Cases (5 tests)
|
||||
- `TestRedisBackend_Stats` - Statistics tracking
|
||||
- `TestRedisBackend_Ping` - Connection health
|
||||
- `TestRedisBackend_Close` - Resource cleanup
|
||||
- `TestRedisBackend_UpdateExisting` - Overwrite handling
|
||||
- `TestRedisBackend_LargeValues` - 1MB values
|
||||
- `TestRedisBackend_EmptyValues` - Empty arrays
|
||||
|
||||
**Coverage:** ~88% of Redis backend code
|
||||
|
||||
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
|
||||
- All basic Redis commands
|
||||
- TTL and expiration
|
||||
- Time manipulation (FastForward)
|
||||
- Error simulation
|
||||
- No external Redis server required
|
||||
|
||||
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
|
||||
|
||||
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### State Transitions (5 tests)
|
||||
- `TestCircuitBreaker_StateTransitions`
|
||||
- Initial state (Closed)
|
||||
- Closed → Open (after max failures)
|
||||
- Open → HalfOpen (after timeout)
|
||||
- HalfOpen → Closed (after successful requests)
|
||||
- HalfOpen → Open (on failure)
|
||||
|
||||
#### Behavior Tests (5 tests)
|
||||
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
|
||||
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
|
||||
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
|
||||
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
|
||||
- `TestCircuitBreaker_Stats` - Statistics tracking
|
||||
|
||||
#### Advanced Tests (7 tests)
|
||||
- `TestCircuitBreaker_Reset` - Manual reset
|
||||
- `TestCircuitBreaker_StateChangeCallback` - Notifications
|
||||
- `TestCircuitBreaker_IsAvailable` - Availability check
|
||||
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
|
||||
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
|
||||
- `TestCircuitBreaker_DefaultConfig` - Default configuration
|
||||
- `TestCircuitBreaker_StateString` - String representation
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkCircuitBreaker_Execute` - Successful operations
|
||||
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
|
||||
|
||||
**Coverage:** ~95% of circuit breaker code
|
||||
|
||||
### 5. Health Check Tests (`health_check_test.go`)
|
||||
|
||||
**Purpose:** Validate periodic health checking and status management.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Status Transitions (4 tests)
|
||||
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
|
||||
- `TestHealthChecker_InitialState` - Default healthy state
|
||||
- `TestHealthChecker_ForceCheck` - Manual health check trigger
|
||||
- `TestHealthChecker_StatusChangeCallback` - Change notifications
|
||||
|
||||
#### Behavior Tests (6 tests)
|
||||
- `TestHealthChecker_Stats` - Statistics tracking
|
||||
- `TestHealthChecker_Timeout` - Check timeout handling
|
||||
- `TestHealthChecker_ConcurrentAccess` - Thread safety
|
||||
- `TestHealthChecker_StopAndStart` - Lifecycle management
|
||||
- `TestHealthChecker_DegradedState` - Degraded status detection
|
||||
- `TestHealthChecker_DefaultConfig` - Default settings
|
||||
|
||||
#### Advanced Tests (2 tests)
|
||||
- `TestHealthChecker_StatusString` - String representation
|
||||
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkHealthChecker_ForceCheck` - Check performance
|
||||
- `BenchmarkHealthChecker_Status` - Status read performance
|
||||
|
||||
**Coverage:** ~90% of health checker code
|
||||
|
||||
### 6. Integration Tests (`redis_integration_test.go`)
|
||||
|
||||
**Purpose:** End-to-end testing of real-world scenarios.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Multi-Instance Tests (3 tests)
|
||||
- `TestRedisIntegration_MultipleInstances`
|
||||
- ShareTokenBlacklist - JTI sharing across Traefik replicas
|
||||
- ShareTokenCache - Token cache sharing
|
||||
- ShareMetadataCache - Provider metadata sharing
|
||||
|
||||
#### Replay Detection (2 tests)
|
||||
- `TestRedisIntegration_JTIReplayDetection`
|
||||
- PreventReplayAcrossInstances - Block used JTIs
|
||||
- ConcurrentJTIChecks - Race condition handling
|
||||
|
||||
#### Resilience (1 test)
|
||||
- `TestRedisIntegration_Failover`
|
||||
- RedisTemporaryFailure - Recovery from temporary failures
|
||||
|
||||
#### Performance (1 test)
|
||||
- `TestRedisIntegration_HighLoad`
|
||||
- HighConcurrency - 50 goroutines × 100 operations
|
||||
|
||||
#### Consistency (2 tests)
|
||||
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
|
||||
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
|
||||
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
|
||||
|
||||
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
|
||||
|
||||
## Test Helpers and Infrastructure
|
||||
|
||||
### Test Helpers (`test_helpers_test.go`)
|
||||
|
||||
**Utilities:**
|
||||
- `TestLogger` - Logging for tests
|
||||
- `MiniredisServer` - Miniredis setup/teardown
|
||||
- `TestConfig` - Default test configurations
|
||||
- `GenerateTestData` - Test data generation
|
||||
- `GenerateLargeValue` - Large value creation
|
||||
- `AssertCacheStats` - Statistics validation
|
||||
- `WaitForCondition` - Async condition waiting
|
||||
- `AssertEventuallyExpires` - TTL expiration verification
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -v
|
||||
go test ./internal/cache/resilience/... -v
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Memory backend only
|
||||
go test ./internal/cache/backend -run TestMemoryBackend -v
|
||||
|
||||
# Redis backend only
|
||||
go test ./internal/cache/backend -run TestRedisBackend -v
|
||||
|
||||
# Circuit breaker only
|
||||
go test ./internal/cache/resilience -run TestCircuitBreaker -v
|
||||
|
||||
# Integration tests only
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -coverprofile=coverage.out
|
||||
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
go test ./internal/cache/backend -bench=. -benchmem
|
||||
go test ./internal/cache/resilience -bench=. -benchmem
|
||||
```
|
||||
|
||||
### Run with Race Detector
|
||||
```bash
|
||||
go test ./internal/cache/... -race -v
|
||||
```
|
||||
|
||||
## Test Patterns Used
|
||||
|
||||
### 1. Table-Driven Tests
|
||||
Used for testing multiple scenarios with similar structure.
|
||||
|
||||
### 2. Subtests (t.Run)
|
||||
Organized test cases into logical groups with clear names.
|
||||
|
||||
### 3. Parallel Tests
|
||||
Tests marked with `t.Parallel()` for faster execution.
|
||||
|
||||
### 4. Test Fixtures
|
||||
Reusable setup functions for common test data.
|
||||
|
||||
### 5. Mocking
|
||||
- `miniredis` for Redis operations
|
||||
- Mock functions for callbacks and health checks
|
||||
|
||||
### 6. Assertion Helpers
|
||||
Using `testify/assert` and `testify/require` for clear assertions.
|
||||
|
||||
## Test Coverage Summary
|
||||
|
||||
| Component | Coverage | Tests | Lines of Code |
|
||||
|-----------|----------|-------|---------------|
|
||||
| Interface Contract | 95% | 14 | ~200 |
|
||||
| Memory Backend | 92% | 18 | ~350 |
|
||||
| Redis Backend | 88% | 21 | ~400 |
|
||||
| Circuit Breaker | 95% | 17 | ~250 |
|
||||
| Health Checker | 90% | 12 | ~200 |
|
||||
| Integration Tests | 80% | 9 | ~300 |
|
||||
| **Total** | **90%** | **91** | **~1,700** |
|
||||
|
||||
## Edge Cases Tested
|
||||
|
||||
1. **Empty values** - Zero-length byte arrays
|
||||
2. **Large values** - 1MB+ data
|
||||
3. **Special characters** - Keys with :, /, -, _, ., |
|
||||
4. **Concurrent access** - 10-50 goroutines
|
||||
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
|
||||
6. **Connection failures** - Network errors, timeouts
|
||||
7. **Redis errors** - Simulated Redis failures
|
||||
8. **Memory limits** - Eviction under memory pressure
|
||||
9. **Race conditions** - Concurrent JTI checks
|
||||
10. **State transitions** - All circuit breaker and health check states
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Benchmarks included for:
|
||||
- Cache operations (Set, Get, Delete)
|
||||
- Circuit breaker execution
|
||||
- Health check operations
|
||||
- Concurrent access patterns
|
||||
- Large datasets (10,000+ items)
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Testing Libraries
|
||||
- `github.com/stretchr/testify` - Assertions and test utilities
|
||||
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
|
||||
- `github.com/redis/go-redis/v9` - Redis client
|
||||
|
||||
### Why Miniredis?
|
||||
- **No external dependencies** - No Redis server required
|
||||
- **Fast** - In-memory, perfect for unit tests
|
||||
- **Full Redis API** - Supports all operations we need
|
||||
- **Time manipulation** - FastForward for TTL testing
|
||||
- **Error simulation** - Test failure scenarios
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Tests
|
||||
1. Hybrid backend tests (L1/L2 cache)
|
||||
2. Network partition scenarios
|
||||
3. Redis cluster support
|
||||
4. Persistence and recovery tests
|
||||
5. Metrics and monitoring integration
|
||||
|
||||
### Test Infrastructure Improvements
|
||||
1. Test containers for real Redis integration
|
||||
2. Performance regression tracking
|
||||
3. Chaos engineering tests
|
||||
4. Load testing framework
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### Recommended CI Configuration
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- go test ./internal/cache/... -race -cover -v
|
||||
- go test -run TestRedisIntegration -v
|
||||
- go test ./internal/cache/... -bench=. -benchmem
|
||||
```
|
||||
|
||||
## Maintenance Guidelines
|
||||
|
||||
1. **Add tests for new features** - Maintain >85% coverage
|
||||
2. **Update contract tests** - When interface changes
|
||||
3. **Test edge cases** - Always test error paths
|
||||
4. **Document test purpose** - Clear comments explaining what each test validates
|
||||
5. **Keep tests fast** - Use t.Parallel() where possible
|
||||
6. **Mock external dependencies** - Use miniredis, not real Redis
|
||||
|
||||
## Conclusion
|
||||
|
||||
This comprehensive test suite provides:
|
||||
- **High confidence** in cache backend correctness
|
||||
- **Fast feedback** - Tests run in seconds
|
||||
- **Good coverage** - 90% overall
|
||||
- **Clear documentation** - Each test is well-documented
|
||||
- **Maintainability** - Clear structure and patterns
|
||||
|
||||
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
|
||||
+390
@@ -0,0 +1,390 @@
|
||||
# Testing Guide
|
||||
|
||||
Comprehensive testing infrastructure for traefikoidc.
|
||||
|
||||
## Overview
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Test files | 99 |
|
||||
| Lines of test code | ~65,500 |
|
||||
| Code coverage | 71.0% |
|
||||
| Race conditions | None (all pass with `-race`) |
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run with race detection
|
||||
go test -race ./...
|
||||
|
||||
# Run with coverage
|
||||
go test -cover ./...
|
||||
|
||||
# Run specific test suite
|
||||
go test -v -run "TokenValidationSuite" .
|
||||
|
||||
# Run edge case tests
|
||||
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
|
||||
```
|
||||
|
||||
## Test Infrastructure
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
internal/testutil/
|
||||
├── compat.go # Re-exports for main package access
|
||||
├── mocks/
|
||||
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
|
||||
│ ├── session.go # SessionManager, SessionData
|
||||
│ ├── cache.go # Cache, TokenCache, Blacklist
|
||||
│ └── interfaces_test.go # Mock verification tests
|
||||
├── fixtures/
|
||||
│ └── tokens.go # JWT token generation fixtures
|
||||
└── servers/
|
||||
├── oidc.go # Mock OIDC server factory
|
||||
└── oidc_test.go # Server tests
|
||||
```
|
||||
|
||||
### Test Suites
|
||||
|
||||
| Suite | File | Description |
|
||||
|-------|------|-------------|
|
||||
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
|
||||
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
|
||||
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
|
||||
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
|
||||
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
|
||||
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
|
||||
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
|
||||
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
|
||||
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
|
||||
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
|
||||
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
|
||||
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
|
||||
|
||||
## Mock Types
|
||||
|
||||
The project provides two mocking patterns:
|
||||
|
||||
### State-Based Mocks (Basic)
|
||||
|
||||
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
|
||||
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
|
||||
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
|
||||
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
|
||||
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
|
||||
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
jwkCache: mock,
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### Enhanced State-Based Mocks (with Call Tracking)
|
||||
|
||||
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
|
||||
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
|
||||
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
|
||||
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
}
|
||||
|
||||
// Make calls
|
||||
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
|
||||
|
||||
// Verify calls were made
|
||||
mock.AssertGetJWKSCalled(t)
|
||||
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(t, 1)
|
||||
|
||||
// Access call details
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Track all calls with parameters and timestamps
|
||||
- Built-in assertion helpers using testify
|
||||
- Thread-safe for concurrent tests
|
||||
- `Reset()` method to clear state between tests
|
||||
- `LastCall()` to inspect most recent call
|
||||
|
||||
### Testify-Based Mocks
|
||||
|
||||
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
|
||||
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
|
||||
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
|
||||
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
|
||||
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
|
||||
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &TestifyJWKCache{}
|
||||
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
|
||||
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
|
||||
|
||||
// After test
|
||||
mock.AssertExpectations(t)
|
||||
```
|
||||
|
||||
### Testutil Package Mocks
|
||||
|
||||
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
|
||||
|
||||
```go
|
||||
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
|
||||
mock := testutil.NewJWKCacheMock()
|
||||
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
|
||||
```
|
||||
|
||||
### Choosing the Right Mock
|
||||
|
||||
| Use Case | Recommended Mock |
|
||||
|----------|-----------------|
|
||||
| Simple return values only | Basic state-based (`MockJWKCache`) |
|
||||
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
|
||||
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
|
||||
| Verify call order/sequence | Testify-based |
|
||||
| HTTP endpoint simulation | `MockOAuthProvider` |
|
||||
| New testify suite tests | Enhanced or Testify-based |
|
||||
|
||||
**Decision Guide:**
|
||||
|
||||
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
|
||||
|
||||
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
|
||||
|
||||
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
|
||||
|
||||
## Token Fixtures
|
||||
|
||||
The `testutil.TokenFixture` generates JWT tokens for testing:
|
||||
|
||||
```go
|
||||
fixture, err := testutil.NewTokenFixture()
|
||||
|
||||
// Valid token with default claims
|
||||
token, _ := fixture.ValidToken(nil)
|
||||
|
||||
// Token with custom claims
|
||||
token, _ := fixture.ValidToken(map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"roles": []string{"admin"},
|
||||
})
|
||||
|
||||
// Expired token
|
||||
token, _ := fixture.ExpiredToken()
|
||||
|
||||
// Token with specific roles/groups
|
||||
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
|
||||
token, _ := fixture.TokenWithGroups([]string{"developers"})
|
||||
|
||||
// Token with clock skew
|
||||
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
|
||||
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
|
||||
|
||||
// Token missing specific claims
|
||||
token, _ := fixture.TokenMissingClaim("email", "sub")
|
||||
|
||||
// Malformed token
|
||||
token := fixture.MalformedToken() // "not.a.valid.jwt"
|
||||
|
||||
// Get JWKS for verification
|
||||
jwks := fixture.GetJWKS()
|
||||
```
|
||||
|
||||
## Mock OIDC Server
|
||||
|
||||
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
|
||||
|
||||
```go
|
||||
// Default configuration
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
defer server.Close()
|
||||
|
||||
// Custom configuration
|
||||
config := testutil.DefaultServerConfig()
|
||||
config.Issuer = "https://custom-issuer.com"
|
||||
config.TokenError = &testutil.OIDCError{
|
||||
Error: "invalid_grant",
|
||||
Description: "Authorization code expired",
|
||||
}
|
||||
server := testutil.NewOIDCServer(config)
|
||||
|
||||
// Provider-specific configurations
|
||||
googleConfig := testutil.GoogleServerConfig()
|
||||
azureConfig := testutil.AzureServerConfig()
|
||||
auth0Config := testutil.Auth0ServerConfig()
|
||||
keycloakConfig := testutil.KeycloakServerConfig()
|
||||
|
||||
// Behavior configurations
|
||||
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
|
||||
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
|
||||
```
|
||||
|
||||
### Server Endpoints
|
||||
|
||||
| Endpoint | Description |
|
||||
|----------|-------------|
|
||||
| `/.well-known/openid-configuration` | OIDC discovery document |
|
||||
| `/authorize` | Authorization endpoint |
|
||||
| `/token` | Token exchange endpoint |
|
||||
| `/jwks` | JSON Web Key Set |
|
||||
| `/userinfo` | User information endpoint |
|
||||
| `/introspect` | Token introspection |
|
||||
| `/revoke` | Token revocation |
|
||||
| `/logout` | End session endpoint |
|
||||
|
||||
### Request Tracking
|
||||
|
||||
```go
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
|
||||
// Make requests...
|
||||
|
||||
count := server.GetRequestCount()
|
||||
requests := server.GetRequests()
|
||||
server.Reset() // Clear tracking
|
||||
```
|
||||
|
||||
## Writing Test Suites
|
||||
|
||||
### Basic Suite Structure
|
||||
|
||||
```go
|
||||
type MyTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupTest() {
|
||||
// Per-test setup
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TearDownTest() {
|
||||
// Per-test cleanup
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TestSomething() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestMyTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MyTestSuite))
|
||||
}
|
||||
```
|
||||
|
||||
### Table-Driven Tests
|
||||
|
||||
```go
|
||||
func (s *MyTestSuite) TestClockSkewEdgeCases() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skew time.Duration
|
||||
shouldPass bool
|
||||
}{
|
||||
{"valid_token", 5 * time.Minute, true},
|
||||
{"expired_within_tolerance", -1 * time.Minute, true},
|
||||
{"expired_beyond_tolerance", -10 * time.Minute, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
token, err := s.fixture.TokenWithSkew(tc.skew)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
if tc.shouldPass {
|
||||
s.NoError(err)
|
||||
} else {
|
||||
s.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Happy Path Tests
|
||||
|
||||
Test the expected successful scenarios:
|
||||
|
||||
- Valid token verification
|
||||
- Successful token exchange
|
||||
- Session creation and retrieval
|
||||
- Cache operations
|
||||
|
||||
### Error Case Tests
|
||||
|
||||
Test failure scenarios:
|
||||
|
||||
- Expired tokens
|
||||
- Invalid signatures
|
||||
- Wrong issuer/audience
|
||||
- Network failures
|
||||
- Rate limiting
|
||||
|
||||
### Edge Case Tests
|
||||
|
||||
Test boundary conditions:
|
||||
|
||||
- Clock skew tolerance boundaries
|
||||
- Unicode/emoji in claims
|
||||
- Very large claim values
|
||||
- Concurrent access
|
||||
- Special characters in URLs
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use fixtures for token generation** - Don't manually construct JWTs
|
||||
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
|
||||
3. **Always run with `-race`** - Catch concurrency issues early
|
||||
4. **Use testify assertions** - Better error messages and cleaner code
|
||||
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
|
||||
6. **Test edge cases systematically** - Use table-driven tests
|
||||
@@ -1,308 +0,0 @@
|
||||
# Test Execution Guide
|
||||
|
||||
This guide explains how to run tests efficiently with the new test categorization and optimization system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Fast Development Testing (Default - Target: < 30 seconds)
|
||||
```bash
|
||||
# Run quick smoke tests only
|
||||
go test ./...
|
||||
|
||||
# Or explicitly run in short mode
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Extended Testing (Target: 2-5 minutes)
|
||||
```bash
|
||||
# Enable extended tests with more iterations and concurrency
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Or use the flag equivalent (if using test runner that supports it)
|
||||
go test ./... -extended
|
||||
```
|
||||
|
||||
### Long-Running Performance Tests (Target: 5-15 minutes)
|
||||
```bash
|
||||
# Enable comprehensive performance and stress tests
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Full Stress Testing (Target: 10-30 minutes)
|
||||
```bash
|
||||
# Enable all stress tests with maximum parameters
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Quick Tests (Default)
|
||||
- **Purpose**: Fast feedback during development
|
||||
- **Duration**: < 30 seconds total
|
||||
- **Features**:
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Minimal concurrency
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Cache Size: 50
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### 2. Extended Tests
|
||||
- **Purpose**: Comprehensive testing before commits
|
||||
- **Duration**: 2-5 minutes
|
||||
- **Features**:
|
||||
- Increased test coverage
|
||||
- More iterations (5-10)
|
||||
- Medium concurrency tests
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Cache Size: 200
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### 3. Long Tests
|
||||
- **Purpose**: Performance validation and stress testing
|
||||
- **Duration**: 5-15 minutes
|
||||
- **Features**:
|
||||
- High iteration counts (50-100)
|
||||
- High concurrency scenarios
|
||||
- Large data sets
|
||||
- Comprehensive memory testing
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Cache Size: 1000
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### 4. Stress Tests
|
||||
- **Purpose**: Maximum load testing and edge case validation
|
||||
- **Duration**: 10-30 minutes
|
||||
- **Features**:
|
||||
- Extreme iteration counts (100-500)
|
||||
- Maximum concurrency (100+)
|
||||
- Large memory allocations
|
||||
- Edge case combinations
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Cache Size: 2000
|
||||
- Timeout: 120 seconds
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Test Execution Control
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1 # Enable extended tests
|
||||
export RUN_LONG_TESTS=1 # Enable long-running tests
|
||||
export RUN_STRESS_TESTS=1 # Enable stress tests
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
|
||||
```
|
||||
|
||||
### Parameter Customization
|
||||
```bash
|
||||
# Customize concurrency limits
|
||||
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
|
||||
|
||||
# Customize iteration limits
|
||||
export TEST_MAX_ITERATIONS=50 # Override max test iterations
|
||||
|
||||
# Customize memory thresholds
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
|
||||
```
|
||||
|
||||
## Test-Specific Behavior
|
||||
|
||||
### Memory Leak Tests
|
||||
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
|
||||
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
|
||||
- **Long Mode**: 50-100 iterations, large data sets, performance focus
|
||||
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
|
||||
|
||||
### Concurrency Tests
|
||||
- **Quick Mode**: 2-5 concurrent operations, basic race detection
|
||||
- **Extended Mode**: 10-20 concurrent operations, moderate stress
|
||||
- **Long Mode**: 20-50 concurrent operations, high contention
|
||||
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
|
||||
|
||||
### Cache Tests
|
||||
- **Quick Mode**: Small caches (50 items), basic operations
|
||||
- **Extended Mode**: Medium caches (200 items), varied operations
|
||||
- **Long Mode**: Large caches (1000 items), performance testing
|
||||
- **Stress Mode**: Very large caches (2000+ items), stress testing
|
||||
|
||||
## Integration with CI/CD
|
||||
|
||||
### GitHub Actions Example
|
||||
```yaml
|
||||
# Quick tests for every push/PR
|
||||
- name: Quick Tests
|
||||
run: go test ./... -short
|
||||
|
||||
# Extended tests for main branch
|
||||
- name: Extended Tests
|
||||
if: github.ref == 'refs/heads/main'
|
||||
run: RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Nightly comprehensive testing
|
||||
- name: Nightly Stress Tests
|
||||
if: github.event_name == 'schedule'
|
||||
run: RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Local Development Workflow
|
||||
```bash
|
||||
# During active development
|
||||
go test ./... -short
|
||||
|
||||
# Before committing
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Before major releases
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Performance validation
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Performance Optimization Features
|
||||
|
||||
### Dynamic Test Scaling
|
||||
The test system automatically adjusts parameters based on:
|
||||
- Test mode (quick/extended/long/stress)
|
||||
- Available resources
|
||||
- Environment variables
|
||||
- Previous test performance
|
||||
|
||||
### Memory Management
|
||||
- **Garbage Collection**: Forced GC between test iterations
|
||||
- **Memory Monitoring**: Real-time memory growth tracking
|
||||
- **Leak Detection**: Goroutine and memory leak prevention
|
||||
- **Resource Cleanup**: Automatic cleanup of test resources
|
||||
|
||||
### Timeout Management
|
||||
- **Adaptive Timeouts**: Timeouts scale with test complexity
|
||||
- **Graceful Degradation**: Tests adapt to slower environments
|
||||
- **Early Termination**: Failed tests terminate quickly
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Tests Taking Too Long
|
||||
```bash
|
||||
# Check if running in extended mode accidentally
|
||||
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
|
||||
|
||||
# Force quick mode
|
||||
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
```bash
|
||||
# Reduce memory limits for constrained environments
|
||||
export TEST_MEMORY_THRESHOLD_MB=5.0
|
||||
export TEST_MAX_CONCURRENCY=2
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Concurrency Issues
|
||||
```bash
|
||||
# Reduce concurrency for slower systems
|
||||
export TEST_MAX_CONCURRENCY=5
|
||||
export TEST_MAX_ITERATIONS=10
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Skip Specific Test Types
|
||||
```bash
|
||||
# Skip memory leak detection if problematic
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
go test ./...
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Running Benchmarks
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
### Benchmark Categories
|
||||
- **Basic Operations**: Set/Get performance
|
||||
- **Concurrency**: Multi-threaded performance
|
||||
- **Memory**: Allocation and cleanup performance
|
||||
- **Cache**: Eviction and cleanup performance
|
||||
|
||||
## Best Practices
|
||||
|
||||
### For Developers
|
||||
1. Always run quick tests during development (`go test ./... -short`)
|
||||
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
|
||||
3. Use appropriate test categories for your use case
|
||||
4. Monitor test execution time and adjust if needed
|
||||
|
||||
### For CI/CD
|
||||
1. Use quick tests for fast feedback on PRs
|
||||
2. Use extended tests for main branch validation
|
||||
3. Use long tests for release validation
|
||||
4. Use stress tests for nightly/weekly validation
|
||||
|
||||
### For Performance Testing
|
||||
1. Use consistent environment variables
|
||||
2. Run tests multiple times for statistical significance
|
||||
3. Monitor both execution time and resource usage
|
||||
4. Use profiling tools for detailed analysis
|
||||
|
||||
## Examples
|
||||
|
||||
### Daily Development
|
||||
```bash
|
||||
# Fast tests while coding
|
||||
go test ./... -short
|
||||
|
||||
# Before git commit
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Release Testing
|
||||
```bash
|
||||
# Comprehensive validation
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Stress testing
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
```bash
|
||||
# Custom limits for specific environment
|
||||
export TEST_MAX_CONCURRENCY=8
|
||||
export TEST_MAX_ITERATIONS=25
|
||||
export TEST_MEMORY_THRESHOLD_MB=15.0
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
|
||||
@@ -1,163 +0,0 @@
|
||||
# Google OAuth Integration Fix
|
||||
|
||||
## Problem Overview
|
||||
|
||||
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
|
||||
|
||||
```
|
||||
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
|
||||
```
|
||||
|
||||
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
|
||||
|
||||
## Technical Details of the Issue
|
||||
|
||||
### Standard OIDC Provider Behavior
|
||||
|
||||
Most OpenID Connect (OIDC) providers follow the standard specification, where:
|
||||
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
|
||||
- This allows authenticated sessions to persist beyond the initial access token expiration
|
||||
|
||||
### Google's Non-Standard Approach
|
||||
|
||||
Google's OAuth implementation deviates from the standard by:
|
||||
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
|
||||
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
|
||||
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
|
||||
|
||||
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
|
||||
|
||||
## Solution Implementation
|
||||
|
||||
The fix involved modifying the authentication flow to specifically handle Google providers:
|
||||
|
||||
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
|
||||
|
||||
```go
|
||||
// Check if we're dealing with a Google OIDC provider
|
||||
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
|
||||
strings.Contains(t.issuerURL, "accounts.google.com")
|
||||
```
|
||||
|
||||
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
|
||||
|
||||
```go
|
||||
// Handle offline access differently for Google vs other providers
|
||||
if isGoogleProvider {
|
||||
// For Google, use access_type=offline parameter instead of offline_access scope
|
||||
params.Set("access_type", "offline")
|
||||
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
// Add prompt=consent for Google to ensure refresh token is issued
|
||||
params.Set("prompt", "consent")
|
||||
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else {
|
||||
// For non-Google providers, use the offline_access scope
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
|
||||
|
||||
## Why This Approach Works
|
||||
|
||||
This solution aligns with Google's OAuth 2.0 documentation which specifies:
|
||||
|
||||
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
|
||||
|
||||
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
|
||||
|
||||
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
|
||||
|
||||
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
|
||||
|
||||
## Testing and Verification
|
||||
|
||||
Comprehensive tests were implemented to verify the solution:
|
||||
|
||||
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
|
||||
|
||||
2. **Auth URL Parameter Tests**: Verifies that:
|
||||
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
|
||||
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
|
||||
|
||||
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
|
||||
|
||||
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
|
||||
|
||||
Sample test case (simplified):
|
||||
|
||||
```go
|
||||
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
|
||||
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that access_type=offline was added (not offline_access scope for Google)
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is NOT included for Google providers
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Check that prompt=consent was added
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Usage Guidance for Developers
|
||||
|
||||
When configuring the Traefik OIDC middleware for Google:
|
||||
|
||||
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
|
||||
|
||||
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
|
||||
- Configure the authorized redirect URI to match your `callbackURL` setting
|
||||
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
|
||||
|
||||
3. **Configuration Example**:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-google-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-google-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware handles this automatically and correctly
|
||||
```
|
||||
|
||||
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
|
||||
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
|
||||
- Review your application logs with `logLevel: debug` to check for refresh token errors
|
||||
- Verify you're using a version of the middleware that includes this fix
|
||||
|
||||
## Conclusion
|
||||
|
||||
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
|
||||
@@ -348,6 +348,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
|
||||
type ClockSkewEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
|
||||
// Create JWK for the test key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error") // Reduce noise
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(0)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Token at exact expiry - behavior is implementation-defined
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.T().Logf("Exact expiry result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token should be valid 1 second before expiry")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
|
||||
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
|
||||
// Most implementations allow 5-minute clock skew
|
||||
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// May pass or fail depending on implementation
|
||||
s.T().Logf("4-minute expired token result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
|
||||
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.Error(err, "Token should be invalid 10 minutes after expiry")
|
||||
}
|
||||
|
||||
func TestClockSkewEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClockSkewEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// UnicodeClaimsSuite tests Unicode handling in JWT claims
|
||||
type UnicodeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
|
||||
token, err := s.fixture.TokenWithEmail("用户@example.com")
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode email should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeName() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "田中太郎",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode name should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test User 😀",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Emoji in claims should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestRTLText() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "مستخدم اختبار",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "RTL text should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestMixedScripts() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test 测试 テスト",
|
||||
"roles": []string{"admin", "管理者", "管理员"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Mixed scripts should be handled correctly")
|
||||
}
|
||||
|
||||
func TestUnicodeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(UnicodeClaimsSuite))
|
||||
}
|
||||
|
||||
// LargeClaimsSuite tests large claim values
|
||||
type LargeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyRoles() {
|
||||
roles := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithRoles(roles)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 100 roles should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyGroups() {
|
||||
groups := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithGroups(groups)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 50 groups should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongEmail() {
|
||||
// RFC 5321 allows up to 254 characters
|
||||
localPart := strings.Repeat("a", 64)
|
||||
domain := strings.Repeat("b", 63) + ".com"
|
||||
email := localPart + "@" + domain
|
||||
|
||||
token, err := s.fixture.TokenWithEmail(email)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long email should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongSubject() {
|
||||
longSub := strings.Repeat("subject", 100)
|
||||
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"sub": longSub,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long subject should be handled")
|
||||
}
|
||||
|
||||
func TestLargeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(LargeClaimsSuite))
|
||||
}
|
||||
|
||||
// URLPathEdgeCasesSuite tests URL handling edge cases
|
||||
type URLPathEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
|
||||
longPath := "/" + strings.Repeat("segment/", 100)
|
||||
req := httptest.NewRequest("GET", longPath, nil)
|
||||
|
||||
s.NotNil(req)
|
||||
s.Contains(req.URL.Path, "segment")
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
|
||||
paths := []string{
|
||||
"/path%20with%20spaces",
|
||||
"/path/with/日本語",
|
||||
"/path?query=value&another=test",
|
||||
"/path#fragment",
|
||||
"/path/../traversal",
|
||||
"/path/./current",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
s.Run(path, func() {
|
||||
req := httptest.NewRequest("GET", path, nil)
|
||||
s.NotNil(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
s.Equal("/", req.URL.Path)
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
|
||||
req := httptest.NewRequest("GET", "//double//slashes//", nil)
|
||||
s.NotNil(req)
|
||||
}
|
||||
|
||||
func TestURLPathEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(URLPathEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
|
||||
type ConcurrencyEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"custom": idx,
|
||||
})
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent different token error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent different token validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
|
||||
validToken, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
expiredToken, err := s.fixture.ExpiredToken()
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 40
|
||||
var wg sync.WaitGroup
|
||||
validCount := int32(0)
|
||||
expiredCount := int32(0)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
var token string
|
||||
if idx%2 == 0 {
|
||||
token = validToken
|
||||
} else {
|
||||
token = expiredToken
|
||||
}
|
||||
|
||||
err := s.tOidc.VerifyToken(token)
|
||||
if idx%2 == 0 {
|
||||
if err == nil {
|
||||
atomic.AddInt32(&validCount, 1)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
atomic.AddInt32(&expiredCount, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
|
||||
}
|
||||
|
||||
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
|
||||
type EnhancedMocksSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Another call with different URL
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
|
||||
|
||||
// Verify calls were tracked
|
||||
s.Equal(2, mock.GetJWKSCallCount())
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(s.T(), 2)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
|
||||
expectedErr := errors.New("network error")
|
||||
mock := &EnhancedMockJWKCache{
|
||||
Err: expectedErr,
|
||||
}
|
||||
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
|
||||
s.Nil(result)
|
||||
s.Equal(expectedErr, err)
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
|
||||
mock.Reset()
|
||||
|
||||
s.Equal(0, mock.GetJWKSCallCount())
|
||||
s.Nil(mock.JWKS)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
Err: nil, // Valid tokens
|
||||
}
|
||||
|
||||
// Verify a token
|
||||
err := mock.VerifyToken("test-token-1")
|
||||
s.NoError(err)
|
||||
|
||||
// Verify another token
|
||||
err = mock.VerifyToken("test-token-2")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
mock.AssertVerifyTokenCalled(s.T())
|
||||
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
|
||||
|
||||
// Check last call
|
||||
lastCall := mock.LastCall()
|
||||
s.NotNil(lastCall)
|
||||
s.Equal("test-token-2", lastCall.Token)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
|
||||
callCount := 0
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
callCount++
|
||||
if token == "invalid" {
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Valid token
|
||||
err := mock.VerifyToken("valid-token")
|
||||
s.NoError(err)
|
||||
|
||||
// Invalid token
|
||||
err = mock.VerifyToken("invalid")
|
||||
s.Error(err)
|
||||
|
||||
s.Equal(2, callCount)
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeResponse: &TokenResponse{
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Exchange code
|
||||
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
|
||||
s.NoError(err)
|
||||
s.Equal("access-token", resp.AccessToken)
|
||||
|
||||
// Refresh token
|
||||
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
|
||||
s.NoError(err)
|
||||
s.Equal("new-access-token", resp.AccessToken)
|
||||
|
||||
// Revoke token
|
||||
err = mock.RevokeTokenWithProvider("access-token", "access_token")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
mock.AssertExchangeCalled(s.T())
|
||||
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
|
||||
mock.AssertRefreshCalled(s.T())
|
||||
mock.AssertRevokeCalled(s.T())
|
||||
|
||||
s.Equal(1, mock.GetExchangeCallCount())
|
||||
s.Equal(1, mock.GetRefreshCallCount())
|
||||
s.Equal(1, mock.GetRevokeCallCount())
|
||||
|
||||
// Check last exchange call details
|
||||
lastExchange := mock.LastExchangeCall()
|
||||
s.NotNil(lastExchange)
|
||||
s.Equal("authorization_code", lastExchange.GrantType)
|
||||
s.Equal("auth-code", lastExchange.CodeOrToken)
|
||||
s.Equal("https://redirect.com", lastExchange.RedirectURL)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeErr: errors.New("invalid_grant"),
|
||||
RefreshErr: errors.New("refresh_expired"),
|
||||
RevokeErr: errors.New("revoke_failed"),
|
||||
}
|
||||
|
||||
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "invalid_grant")
|
||||
|
||||
_, err = mock.GetNewTokenWithRefreshToken("token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "refresh_expired")
|
||||
|
||||
err = mock.RevokeTokenWithProvider("token", "access_token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "revoke_failed")
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// Set some values
|
||||
mock.Set("key1", "value1", 5*time.Minute)
|
||||
mock.Set("key2", "value2", 10*time.Minute)
|
||||
|
||||
// Get values
|
||||
val, found := mock.Get("key1")
|
||||
s.True(found)
|
||||
s.Equal("value1", val)
|
||||
|
||||
_, found = mock.Get("nonexistent")
|
||||
s.False(found)
|
||||
|
||||
// Delete
|
||||
mock.Delete("key1")
|
||||
|
||||
// Verify tracking
|
||||
mock.AssertSetCalled(s.T(), "key1")
|
||||
mock.AssertSetCalled(s.T(), "key2")
|
||||
mock.AssertGetCalled(s.T(), "key1")
|
||||
mock.AssertGetCalled(s.T(), "nonexistent")
|
||||
mock.AssertDeleteCalled(s.T(), "key1")
|
||||
|
||||
s.Equal(2, mock.SetCallCount())
|
||||
s.Equal(2, mock.GetCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// The enhanced mock actually stores data
|
||||
mock.Set("key", "value", time.Hour)
|
||||
s.Equal(1, mock.Size())
|
||||
|
||||
val, found := mock.Get("key")
|
||||
s.True(found)
|
||||
s.Equal("value", val)
|
||||
|
||||
mock.Delete("key")
|
||||
s.Equal(0, mock.Size())
|
||||
|
||||
_, found = mock.Get("key")
|
||||
s.False(found)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
mock.Set("key1", "value1", time.Hour)
|
||||
mock.Set("key2", "value2", time.Hour)
|
||||
s.Equal(2, mock.Size())
|
||||
|
||||
mock.Clear()
|
||||
s.Equal(0, mock.Size())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Concurrent calls should be safe
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
s.Equal(10, mock.GetJWKSCallCount())
|
||||
}
|
||||
|
||||
func TestEnhancedMocksSuite(t *testing.T) {
|
||||
suite.Run(t, new(EnhancedMocksSuite))
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// EnhancedMockJWKCache is an improved state-based mock with call tracking
|
||||
type EnhancedMockJWKCache struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return)
|
||||
JWKS *JWKSet
|
||||
Err error
|
||||
|
||||
// Call tracking
|
||||
GetJWKSCalls []JWKSCall
|
||||
CleanupCalls int32
|
||||
CloseCalls int32
|
||||
getJWKSCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// JWKSCall records parameters from a GetJWKS call
|
||||
type JWKSCall struct {
|
||||
URL string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
|
||||
URL: jwksURL,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Close() {
|
||||
atomic.AddInt32(&m.CloseCalls, 1)
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetJWKSCalled verifies GetJWKS was called
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
|
||||
}
|
||||
|
||||
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
for _, call := range m.GetJWKSCalls {
|
||||
if call.URL == expectedURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
|
||||
}
|
||||
|
||||
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
|
||||
}
|
||||
|
||||
// GetJWKSCallCount returns the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return len(m.GetJWKSCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockJWKCache) Reset() {
|
||||
m.mu.Lock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = nil
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
atomic.StoreInt32(&m.CleanupCalls, 0)
|
||||
atomic.StoreInt32(&m.CloseCalls, 0)
|
||||
}
|
||||
|
||||
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenVerifier struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return) - can be a fixed error or a function
|
||||
Err error
|
||||
VerifyFunc func(token string) error
|
||||
|
||||
// Call tracking
|
||||
VerifyCalls []TokenVerifyCall
|
||||
verifyCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// TokenVerifyCall records parameters from a VerifyToken call
|
||||
type TokenVerifyCall struct {
|
||||
Token string
|
||||
Timestamp time.Time
|
||||
Result error
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
|
||||
var result error
|
||||
|
||||
m.mu.RLock()
|
||||
if m.VerifyFunc != nil {
|
||||
result = m.VerifyFunc(token)
|
||||
} else {
|
||||
result = m.Err
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
|
||||
Token: token,
|
||||
Timestamp: time.Now(),
|
||||
Result: result,
|
||||
})
|
||||
m.verifyCallsMu.Unlock()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertVerifyTokenCalled verifies VerifyToken was called
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
for _, call := range m.VerifyCalls {
|
||||
if call.Token == expectedToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "VerifyToken was not called with expected token")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
|
||||
}
|
||||
|
||||
// GetVerifyTokenCallCount returns the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return len(m.VerifyCalls)
|
||||
}
|
||||
|
||||
// LastCall returns the most recent VerifyToken call
|
||||
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
if len(m.VerifyCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.VerifyCalls[len(m.VerifyCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenVerifier) Reset() {
|
||||
m.mu.Lock()
|
||||
m.Err = nil
|
||||
m.VerifyFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = nil
|
||||
m.verifyCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenExchanger struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return)
|
||||
ExchangeResponse *TokenResponse
|
||||
ExchangeErr error
|
||||
RefreshResponse *TokenResponse
|
||||
RefreshErr error
|
||||
RevokeErr error
|
||||
|
||||
// Optional functions for dynamic behavior
|
||||
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
|
||||
// Call tracking
|
||||
ExchangeCalls []ExchangeCall
|
||||
RefreshCalls []RefreshCall
|
||||
RevokeCalls []RevokeCall
|
||||
exchangeCallsMu sync.Mutex
|
||||
refreshCallsMu sync.Mutex
|
||||
revokeCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// ExchangeCall records parameters from an ExchangeCodeForToken call
|
||||
type ExchangeCall struct {
|
||||
GrantType string
|
||||
CodeOrToken string
|
||||
RedirectURL string
|
||||
CodeVerifier string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
|
||||
type RefreshCall struct {
|
||||
RefreshToken string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// RevokeCall records parameters from a RevokeTokenWithProvider call
|
||||
type RevokeCall struct {
|
||||
Token string
|
||||
TokenType string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
|
||||
GrantType: grantType,
|
||||
CodeOrToken: codeOrToken,
|
||||
RedirectURL: redirectURL,
|
||||
CodeVerifier: codeVerifier,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.ExchangeCodeFunc != nil {
|
||||
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
return m.ExchangeResponse, m.ExchangeErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
|
||||
RefreshToken: refreshToken,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RefreshTokenFunc != nil {
|
||||
return m.RefreshTokenFunc(refreshToken)
|
||||
}
|
||||
return m.RefreshResponse, m.RefreshErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
|
||||
Token: token,
|
||||
TokenType: tokenType,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.revokeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RevokeTokenFunc != nil {
|
||||
return m.RevokeTokenFunc(token, tokenType)
|
||||
}
|
||||
return m.RevokeErr
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertExchangeCalled verifies ExchangeCodeForToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
|
||||
}
|
||||
|
||||
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
for _, call := range m.ExchangeCalls {
|
||||
if call.GrantType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
|
||||
}
|
||||
|
||||
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
|
||||
}
|
||||
|
||||
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
|
||||
}
|
||||
|
||||
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return len(m.ExchangeCalls)
|
||||
}
|
||||
|
||||
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return len(m.RefreshCalls)
|
||||
}
|
||||
|
||||
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return len(m.RevokeCalls)
|
||||
}
|
||||
|
||||
// LastExchangeCall returns the most recent ExchangeCodeForToken call
|
||||
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
if len(m.ExchangeCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenExchanger) Reset() {
|
||||
m.mu.Lock()
|
||||
m.ExchangeResponse = nil
|
||||
m.ExchangeErr = nil
|
||||
m.RefreshResponse = nil
|
||||
m.RefreshErr = nil
|
||||
m.RevokeErr = nil
|
||||
m.ExchangeCodeFunc = nil
|
||||
m.RefreshTokenFunc = nil
|
||||
m.RevokeTokenFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = nil
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = nil
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = nil
|
||||
m.revokeCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
|
||||
type EnhancedMockCacheInterface struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Internal storage
|
||||
data map[string]cacheEntry
|
||||
maxSize int
|
||||
|
||||
// Call tracking
|
||||
GetCalls []CacheGetCall
|
||||
SetCalls []CacheSetCall
|
||||
DeleteCalls []string
|
||||
getCalls sync.Mutex
|
||||
setCalls sync.Mutex
|
||||
deleteCalls sync.Mutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// CacheGetCall records parameters from a Get call
|
||||
type CacheGetCall struct {
|
||||
Key string
|
||||
Found bool
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// CacheSetCall records parameters from a Set call
|
||||
type CacheSetCall struct {
|
||||
Key string
|
||||
Value any
|
||||
TTL time.Duration
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// NewEnhancedMockCache creates a new enhanced cache mock
|
||||
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
|
||||
return &EnhancedMockCacheInterface{
|
||||
data: make(map[string]cacheEntry),
|
||||
maxSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = append(m.SetCalls, CacheSetCall{
|
||||
Key: key,
|
||||
Value: value,
|
||||
TTL: ttl,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
m.data[key] = cacheEntry{value: value, ttl: ttl}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
entry, found := m.data[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = append(m.GetCalls, CacheGetCall{
|
||||
Key: key,
|
||||
Found: found,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getCalls.Unlock()
|
||||
|
||||
if found {
|
||||
return entry.value, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Delete(key string) {
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = append(m.DeleteCalls, key)
|
||||
m.deleteCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.data, key)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
|
||||
m.mu.Lock()
|
||||
m.maxSize = size
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Size() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.data)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Clear() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Cleanup() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Close() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return map[string]any{
|
||||
"size": len(m.data),
|
||||
"max_size": m.maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetCalled verifies Get was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
for _, call := range m.GetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Get was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertSetCalled verifies Set was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
for _, call := range m.SetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Set was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertDeleteCalled verifies Delete was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
|
||||
m.deleteCalls.Lock()
|
||||
defer m.deleteCalls.Unlock()
|
||||
for _, k := range m.DeleteCalls {
|
||||
if k == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Delete was not called with key: "+key)
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of Get calls
|
||||
func (m *EnhancedMockCacheInterface) GetCallCount() int {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
return len(m.GetCalls)
|
||||
}
|
||||
|
||||
// SetCallCount returns the number of Set calls
|
||||
func (m *EnhancedMockCacheInterface) SetCallCount() int {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
return len(m.SetCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockCacheInterface) Reset() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = nil
|
||||
m.getCalls.Unlock()
|
||||
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = nil
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = nil
|
||||
m.deleteCalls.Unlock()
|
||||
}
|
||||
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -538,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
@@ -1087,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDefaultCircuitBreakerConfig tests the default configuration function
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
// Test default values
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
|
||||
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Check expected metric fields
|
||||
expectedFields := []string{
|
||||
"total_requests",
|
||||
"total_failures",
|
||||
"total_successes",
|
||||
"uptime_seconds",
|
||||
"name",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if _, exists := metrics[field]; !exists {
|
||||
t.Errorf("Expected metric field %s to exist", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordRequest tests request recording
|
||||
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some requests
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalRequests := metrics["total_requests"].(int64)
|
||||
|
||||
if totalRequests != 3 {
|
||||
t.Errorf("Expected 3 total requests, got %d", totalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
|
||||
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some successes
|
||||
base.RecordSuccess()
|
||||
base.RecordSuccess()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalSuccesses := metrics["total_successes"].(int64)
|
||||
|
||||
if totalSuccesses != 2 {
|
||||
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
|
||||
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some failures
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalFailures := metrics["total_failures"].(int64)
|
||||
|
||||
if totalFailures != 3 {
|
||||
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogInfo tests info logging
|
||||
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogInfo("test message")
|
||||
base.LogInfo("test message with args: %s %d", "arg1", 42)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogInfo("test message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogError tests error logging
|
||||
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogError("error message")
|
||||
base.LogError("error message with args: %s %d", "error", 500)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogError("error message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogDebug tests debug logging
|
||||
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogDebug("debug message")
|
||||
base.LogDebug("debug message with args: %s %d", "debug", 123)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogDebug("debug message") // Should not panic
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetState tests getting circuit breaker state
|
||||
func TestCircuitBreaker_GetState(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initial state should be closed
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be closed, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests resetting circuit breaker
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Reset should not panic
|
||||
cb.Reset()
|
||||
|
||||
// State should be closed after reset
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be closed after reset, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsAvailable tests availability check
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initially should be available
|
||||
available := cb.IsAvailable()
|
||||
if !available {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Should include base metrics
|
||||
if _, exists := metrics["total_requests"]; !exists {
|
||||
t.Error("Expected total_requests in metrics")
|
||||
}
|
||||
|
||||
// Should include circuit breaker specific metrics
|
||||
if _, exists := metrics["state"]; !exists {
|
||||
t.Error("Expected state in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
// Retry mechanism tests removed due to complex dependencies
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -1,560 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRetryExecutorReset tests the Reset method
|
||||
func TestRetryExecutorReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
require.NotNil(t, executor)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
executor.Reset()
|
||||
})
|
||||
|
||||
// Multiple resets should be safe
|
||||
executor.Reset()
|
||||
executor.Reset()
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsAvailable tests the IsAvailable method
|
||||
func TestRetryExecutorIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
// Retry executor should always be available
|
||||
assert.True(t, executor.IsAvailable())
|
||||
|
||||
// Should remain available after operations
|
||||
ctx := context.Background()
|
||||
executor.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.True(t, executor.IsAvailable())
|
||||
}
|
||||
|
||||
// TestSessionErrorUnwrap tests SessionError.Unwrap
|
||||
func TestSessionErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("root cause")
|
||||
sessionErr := NewSessionError("save", "failed to save session", rootErr)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
sessionErr := NewSessionError("load", "failed to load session", nil)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("database error")
|
||||
sessionErr := NewSessionError("delete", "failed to delete session", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(sessionErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestTokenErrorUnwrap tests TokenError.Unwrap
|
||||
func TestTokenErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature verification failed")
|
||||
tokenErr := NewTokenError("id_token", "invalid", "token is invalid", rootErr)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
tokenErr := NewTokenError("access_token", "expired", "token has expired", nil)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("crypto error")
|
||||
tokenErr := NewTokenError("refresh_token", "malformed", "token is malformed", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(tokenErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterFallback tests fallback registration
|
||||
func TestGracefulDegradationRegisterFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register single fallback", func(t *testing.T) {
|
||||
fallback := func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
}
|
||||
|
||||
gd.RegisterFallback("service1", fallback)
|
||||
|
||||
// Verify fallback was registered (indirectly)
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return nil, errors.New("service failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("register multiple fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback2", nil
|
||||
})
|
||||
gd.RegisterFallback("service3", func() (interface{}, error) {
|
||||
return "fallback3", nil
|
||||
})
|
||||
|
||||
result2, _ := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
result3, _ := gd.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "fallback2", result2)
|
||||
assert.Equal(t, "fallback3", result3)
|
||||
})
|
||||
|
||||
t.Run("override existing fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "old fallback", nil
|
||||
})
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "new fallback", nil
|
||||
})
|
||||
|
||||
result, _ := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "new fallback", result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterHealthCheck tests health check registration
|
||||
func TestGracefulDegradationRegisterHealthCheck(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register health check", func(t *testing.T) {
|
||||
healthy := true
|
||||
healthCheck := func() bool {
|
||||
return healthy
|
||||
}
|
||||
|
||||
gd.RegisterHealthCheck("service1", healthCheck)
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
|
||||
// Set healthy and wait for health check to run
|
||||
healthy = true
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (may still be degraded due to timing, but health check was registered)
|
||||
})
|
||||
|
||||
t.Run("multiple health checks", func(t *testing.T) {
|
||||
gd.RegisterHealthCheck("service2", func() bool { return true })
|
||||
gd.RegisterHealthCheck("service3", func() bool { return false })
|
||||
|
||||
// Health checks are registered and will be called periodically
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithContext tests ExecuteWithContext
|
||||
func TestGracefulDegradationExecuteWithContext(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("successful execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testErr := errors.New("operation failed")
|
||||
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return testErr
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("uses fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("default", func() (interface{}, error) {
|
||||
return nil, nil // Success fallback
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return errors.New("primary failed")
|
||||
})
|
||||
|
||||
// With fallback succeeding, overall operation succeeds
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithFallback tests ExecuteWithFallback
|
||||
func TestGracefulDegradationExecuteWithFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("primary succeeds", func(t *testing.T) {
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return "primary result", nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "primary result", result)
|
||||
})
|
||||
|
||||
t.Run("fallback succeeds when primary fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("error when no fallback available", func(t *testing.T) {
|
||||
config.EnableFallbacks = false
|
||||
gdNoFallback := NewGracefulDegradation(config, logger)
|
||||
defer gdNoFallback.Close()
|
||||
|
||||
result, err := gdNoFallback.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("fallback also fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback also failed")
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback also failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsServiceDegraded tests service degradation status
|
||||
func TestGracefulDegradationIsServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 100 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("service not degraded initially", func(t *testing.T) {
|
||||
assert.False(t, gd.isServiceDegraded("new-service"))
|
||||
})
|
||||
|
||||
t.Run("service degraded after marking", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
})
|
||||
|
||||
t.Run("service recovers after timeout", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
assert.True(t, gd.isServiceDegraded("service2"))
|
||||
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("service2"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationMarkServiceDegraded tests marking services as degraded
|
||||
func TestGracefulDegradationMarkServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("mark single service", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service1")
|
||||
})
|
||||
|
||||
t.Run("mark multiple services", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service2")
|
||||
assert.Contains(t, degraded, "service3")
|
||||
})
|
||||
|
||||
t.Run("marking same service multiple times updates timestamp", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service4")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
// Service should still be marked as degraded
|
||||
assert.True(t, gd.isServiceDegraded("service4"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteFallback tests fallback execution
|
||||
func TestGracefulDegradationExecuteFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("execute registered fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) {
|
||||
return "fallback value", nil
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback value", result)
|
||||
})
|
||||
|
||||
t.Run("error when fallback not registered", func(t *testing.T) {
|
||||
result, err := gd.executeFallback("non-existent-service")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "no fallback available")
|
||||
})
|
||||
|
||||
t.Run("propagate fallback errors", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback error")
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service2")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback error")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationReset tests Reset method
|
||||
func TestGracefulDegradationReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("reset clears degraded services", func(t *testing.T) {
|
||||
// Mark several services as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 3)
|
||||
|
||||
// Reset
|
||||
gd.Reset()
|
||||
|
||||
// All should be cleared
|
||||
assert.Len(t, gd.GetDegradedServices(), 0)
|
||||
})
|
||||
|
||||
t.Run("can mark services degraded after reset", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 1)
|
||||
assert.Contains(t, gd.GetDegradedServices(), "service4")
|
||||
})
|
||||
|
||||
t.Run("multiple resets are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsAvailable tests IsAvailable method
|
||||
func TestGracefulDegradationIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Should always return true
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even with degraded services
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even after reset
|
||||
gd.Reset()
|
||||
assert.True(t, gd.IsAvailable())
|
||||
}
|
||||
|
||||
// TestGracefulDegradationGetMetrics tests GetMetrics method
|
||||
func TestGracefulDegradationGetMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("basic metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
require.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "degraded_services_count")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
assert.Contains(t, metrics, "registered_fallbacks_count")
|
||||
assert.Contains(t, metrics, "registered_health_checks_count")
|
||||
assert.Contains(t, metrics, "health_check_interval_seconds")
|
||||
assert.Contains(t, metrics, "recovery_timeout_seconds")
|
||||
assert.Contains(t, metrics, "fallbacks_enabled")
|
||||
})
|
||||
|
||||
t.Run("metrics reflect degraded services", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.Equal(t, 2, metrics["degraded_services_count"])
|
||||
degradedList := metrics["degraded_services"].([]string)
|
||||
assert.Len(t, degradedList, 2)
|
||||
})
|
||||
|
||||
t.Run("metrics reflect registered fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) { return nil, nil })
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) { return nil, nil })
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.GreaterOrEqual(t, metrics["registered_fallbacks_count"], 2)
|
||||
})
|
||||
|
||||
t.Run("metrics include base metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
// Should include base recovery mechanism metrics
|
||||
assert.Contains(t, metrics, "name")
|
||||
assert.Contains(t, metrics, "uptime_seconds")
|
||||
assert.Contains(t, metrics, "total_requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationFullScenario tests a complete degradation scenario
|
||||
func TestGracefulDegradationFullScenario(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping full scenario test in short mode")
|
||||
}
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 200 * time.Millisecond
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register fallback
|
||||
gd.RegisterFallback("critical-service", func() (interface{}, error) {
|
||||
return "fallback data", nil
|
||||
})
|
||||
|
||||
// Register health check
|
||||
serviceHealthy := false
|
||||
gd.RegisterHealthCheck("critical-service", func() bool {
|
||||
return serviceHealthy
|
||||
})
|
||||
|
||||
// First call - primary succeeds
|
||||
result1, err1 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "primary data", nil
|
||||
})
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, "primary data", result1)
|
||||
|
||||
// Second call - primary fails, fallback succeeds
|
||||
result2, err2 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service down")
|
||||
})
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, "fallback data", result2)
|
||||
|
||||
// Service is now degraded
|
||||
assert.True(t, gd.isServiceDegraded("critical-service"))
|
||||
|
||||
// Third call - should use fallback immediately
|
||||
result3, err3 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
assert.NoError(t, err3)
|
||||
assert.Equal(t, "fallback data", result3)
|
||||
|
||||
// Mark service as healthy and wait for health check
|
||||
serviceHealthy = true
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (timing-dependent, so we don't assert)
|
||||
|
||||
// Get metrics
|
||||
metrics := gd.GetMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package traefikoidc
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -1,663 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreakerAllowRequestEdgeCases tests edge cases in circuit breaker request allowing
|
||||
func TestCircuitBreakerAllowRequestEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("invalid state returns false", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Force invalid state
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerState(999) // Invalid state
|
||||
cb.mutex.Unlock()
|
||||
|
||||
// Should return false for invalid state
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "invalid state should not allow requests")
|
||||
})
|
||||
|
||||
t.Run("open to half-open transition on timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: baseTimeout,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Verify circuit is open
|
||||
assert.Equal(t, CircuitBreakerOpen, cb.GetState())
|
||||
assert.False(t, cb.allowRequest())
|
||||
|
||||
// Wait for timeout (longer than timeout to ensure transition)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should transition to half-open
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "should allow request after timeout")
|
||||
assert.Equal(t, CircuitBreakerHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("half-open allows requests", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Manually set to half-open
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
cb.mutex.Unlock()
|
||||
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "half-open should allow requests")
|
||||
})
|
||||
|
||||
t.Run("open blocks requests before timeout", func(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 1 * time.Hour, // Long timeout
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should be blocked
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "open circuit should block requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsRetryableErrorEdgeCases tests edge cases for retry decision
|
||||
func TestRetryExecutorIsRetryableErrorEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRetryConfig()
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("nil error is not retryable", func(t *testing.T) {
|
||||
retryable := re.isRetryableError(nil)
|
||||
assert.False(t, retryable)
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 429 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 429,
|
||||
Message: "Too Many Requests",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "429 Too Many Requests should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 500 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 500,
|
||||
Message: "Internal Server Error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "500 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 503 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 503,
|
||||
Message: "Service Unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "503 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 400 is not retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 400,
|
||||
Message: "Bad Request",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.False(t, retryable, "400 errors should not be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with timeout is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: true,
|
||||
temporary: false,
|
||||
msg: "timeout error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "timeout errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection refused is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection refused",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection refused should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection reset is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection reset by peer",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection reset should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with network unreachable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "network is unreachable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "network unreachable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with no route to host is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "no route to host",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "no route to host should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with temporary failure is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "temporary failure in name resolution",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "temporary failure should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with try again is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "try again later",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "try again should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with resource temporarily unavailable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "resource temporarily unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "resource temporarily unavailable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("configured retryable error patterns", func(t *testing.T) {
|
||||
err := errors.New("connection refused by server")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.True(t, retryable, "configured pattern should be retryable")
|
||||
})
|
||||
|
||||
t.Run("non-retryable error", func(t *testing.T) {
|
||||
err := errors.New("invalid input data")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.False(t, retryable, "non-configured error should not be retryable")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorCalculateDelayEdgeCases tests delay calculation edge cases
|
||||
func TestRetryExecutorCalculateDelayEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("delay calculation without jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false, // Jitter disabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 1: 100ms * 2^0 = 100ms
|
||||
delay1 := re.calculateDelay(1)
|
||||
assert.Equal(t, 100*time.Millisecond, delay1)
|
||||
|
||||
// Attempt 2: 100ms * 2^1 = 200ms
|
||||
delay2 := re.calculateDelay(2)
|
||||
assert.Equal(t, 200*time.Millisecond, delay2)
|
||||
|
||||
// Attempt 3: 100ms * 2^2 = 400ms
|
||||
delay3 := re.calculateDelay(3)
|
||||
assert.Equal(t, 400*time.Millisecond, delay3)
|
||||
})
|
||||
|
||||
t.Run("delay calculation with jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true, // Jitter enabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// With jitter, delay should be within 10% of expected
|
||||
delay := re.calculateDelay(2)
|
||||
expectedBase := 200 * time.Millisecond
|
||||
minDelay := time.Duration(float64(expectedBase) * 0.9)
|
||||
maxDelay := time.Duration(float64(expectedBase) * 1.1)
|
||||
|
||||
assert.GreaterOrEqual(t, delay, minDelay, "delay should be >= 90% of base")
|
||||
assert.LessOrEqual(t, delay, maxDelay, "delay should be <= 110% of base")
|
||||
})
|
||||
|
||||
t.Run("delay capped at max delay", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 10,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 500 * time.Millisecond, // Low max delay
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 10: would be 100ms * 2^9 = 51200ms, but capped at 500ms
|
||||
delay := re.calculateDelay(10)
|
||||
assert.Equal(t, 500*time.Millisecond, delay, "delay should be capped at max")
|
||||
})
|
||||
|
||||
t.Run("delay with large backoff factor", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 3.0, // Larger backoff
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 3: 50ms * 3^2 = 450ms
|
||||
delay := re.calculateDelay(3)
|
||||
assert.Equal(t, 450*time.Millisecond, delay)
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorTypesErrorMethodsWithoutCause tests error type Error() methods without cause
|
||||
func TestErrorTypesErrorMethodsWithoutCause(t *testing.T) {
|
||||
t.Run("HTTPError.Error without cause", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 404,
|
||||
Message: "Not Found",
|
||||
}
|
||||
|
||||
errStr := httpErr.Error()
|
||||
assert.Equal(t, "HTTP 404: Not Found", errStr)
|
||||
})
|
||||
|
||||
t.Run("HTTPError.Error with different status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{200, "OK", "HTTP 200: OK"},
|
||||
{301, "Moved", "HTTP 301: Moved"},
|
||||
{401, "Unauthorized", "HTTP 401: Unauthorized"},
|
||||
{500, "Server Error", "HTTP 500: Server Error"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: tc.code,
|
||||
Message: tc.message,
|
||||
}
|
||||
assert.Equal(t, tc.expected, httpErr.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error without cause", func(t *testing.T) {
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_token",
|
||||
Message: "Token validation failed",
|
||||
Context: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Equal(t, "OIDC error [invalid_token]: Token validation failed", errStr)
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature mismatch")
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_signature",
|
||||
Message: "JWT signature invalid",
|
||||
Context: make(map[string]interface{}),
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Contains(t, errStr, "OIDC error [invalid_signature]: JWT signature invalid")
|
||||
assert.Contains(t, errStr, "caused by: signature mismatch")
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error without cause", func(t *testing.T) {
|
||||
sessErr := &SessionError{
|
||||
Operation: "load",
|
||||
Message: "Session not found",
|
||||
SessionID: "sess123",
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Equal(t, "Session error in load: Session not found", errStr)
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("database connection failed")
|
||||
sessErr := &SessionError{
|
||||
Operation: "save",
|
||||
Message: "Failed to persist session",
|
||||
SessionID: "sess456",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Contains(t, errStr, "Session error in save: Failed to persist session")
|
||||
assert.Contains(t, errStr, "caused by: database connection failed")
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error without cause", func(t *testing.T) {
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "access_token",
|
||||
Reason: "expired",
|
||||
Message: "Token has expired",
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Equal(t, "Token error (access_token) - expired: Token has expired", errStr)
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("time check failed")
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "id_token",
|
||||
Reason: "expired",
|
||||
Message: "Token validity period exceeded",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Contains(t, errStr, "Token error (id_token) - expired: Token validity period exceeded")
|
||||
assert.Contains(t, errStr, "caused by: time check failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationHealthChecks tests health check functionality
|
||||
func TestGracefulDegradationHealthChecks(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("performHealthChecks recovers degraded service", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns true
|
||||
healthCheckCalled := false
|
||||
gd.RegisterHealthCheck("test-service", func() bool {
|
||||
healthCheckCalled = true
|
||||
return true // Service is healthy
|
||||
})
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("test-service")
|
||||
|
||||
// Verify service is degraded
|
||||
assert.True(t, gd.isServiceDegraded("test-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Health check should have been called
|
||||
assert.True(t, healthCheckCalled, "health check should be called")
|
||||
|
||||
// Service should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("test-service"), "service should be recovered")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks marks service degraded on failure", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns false
|
||||
gd.RegisterHealthCheck("failing-service", func() bool {
|
||||
return false // Service is unhealthy
|
||||
})
|
||||
|
||||
// Initially not degraded
|
||||
assert.False(t, gd.isServiceDegraded("failing-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Service should be marked degraded
|
||||
assert.True(t, gd.isServiceDegraded("failing-service"), "service should be degraded")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks runs multiple health checks independently", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
service1Checked := false
|
||||
service2Checked := false
|
||||
|
||||
gd.RegisterHealthCheck("service1", func() bool {
|
||||
service1Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
gd.RegisterHealthCheck("service2", func() bool {
|
||||
service2Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
// Manually trigger health checks
|
||||
gd.performHealthChecks()
|
||||
|
||||
assert.True(t, service1Checked, "service1 health check should run")
|
||||
assert.True(t, service2Checked, "service2 health check should run")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks handles empty health checks", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Call performHealthChecks with no registered health checks
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
gd.performHealthChecks()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationServiceRecoveryTimeout tests recovery timeout behavior
|
||||
func TestGracefulDegradationServiceRecoveryTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("service auto-recovers after timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour, // Long interval, won't run during test
|
||||
RecoveryTimeout: baseTimeout,
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("auto-recover-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("auto-recover-service"))
|
||||
|
||||
// Wait for recovery timeout (longer than timeout to ensure recovery)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should auto-recover
|
||||
assert.False(t, gd.isServiceDegraded("auto-recover-service"), "service should auto-recover after timeout")
|
||||
})
|
||||
|
||||
t.Run("service remains degraded before timeout", func(t *testing.T) {
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour,
|
||||
RecoveryTimeout: 1 * time.Hour, // Very long timeout
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("long-timeout-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"))
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(10 * time.Millisecond))
|
||||
|
||||
// Should still be degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"), "service should remain degraded before timeout")
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorRecoveryManagerIntegration tests full integration of error recovery mechanisms
|
||||
func TestErrorRecoveryManagerIntegration(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
erm := NewErrorRecoveryManager(logger)
|
||||
|
||||
t.Run("circuit breaker and retry integration", func(t *testing.T) {
|
||||
// Create a circuit breaker with higher max failures to allow retries
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10, // High threshold
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}, logger)
|
||||
|
||||
erm.mutex.Lock()
|
||||
erm.circuitBreakers["test-service-integration"] = cb
|
||||
erm.mutex.Unlock()
|
||||
|
||||
attempts := 0
|
||||
fn := func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service-integration", fn)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, attempts, 3, "should retry until success")
|
||||
})
|
||||
|
||||
t.Run("circuit breaker opens on repeated failures", func(t *testing.T) {
|
||||
fn := func() error {
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
// First call - should fail after retries
|
||||
err1 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err1)
|
||||
|
||||
// Second call - should fail after retries
|
||||
err2 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err2)
|
||||
|
||||
// Check circuit breaker state
|
||||
cb := erm.GetCircuitBreaker("failing-service")
|
||||
state := cb.GetState()
|
||||
assert.Equal(t, CircuitBreakerOpen, state, "circuit should be open after repeated failures")
|
||||
})
|
||||
|
||||
t.Run("recovery metrics include all mechanisms", func(t *testing.T) {
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "circuit_breakers")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
})
|
||||
}
|
||||
|
||||
// TestContainsHelperFunction tests the contains helper function edge cases
|
||||
func TestContainsHelperFunction(t *testing.T) {
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("prefix match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout error occurred", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("suffix match", func(t *testing.T) {
|
||||
assert.True(t, contains("connection timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("middle match", func(t *testing.T) {
|
||||
assert.True(t, contains("a connection timeout error", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
assert.False(t, contains("connection refused", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("substring longer than string", func(t *testing.T) {
|
||||
assert.False(t, contains("abc", "abcdef"))
|
||||
})
|
||||
|
||||
t.Run("empty substring", func(t *testing.T) {
|
||||
assert.True(t, contains("test", ""))
|
||||
})
|
||||
|
||||
t.Run("empty string", func(t *testing.T) {
|
||||
assert.False(t, contains("", "test"))
|
||||
})
|
||||
|
||||
t.Run("both empty", func(t *testing.T) {
|
||||
assert.True(t, contains("", ""))
|
||||
})
|
||||
}
|
||||
+1178
-37
File diff suppressed because it is too large
Load Diff
@@ -1,797 +0,0 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Mock types for testing
|
||||
type TemplatedHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
// TestTemplateHeaderFeatures consolidates all template header-related tests
|
||||
func TestTemplateHeaderFeatures(t *testing.T) {
|
||||
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
|
||||
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
|
||||
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
|
||||
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
|
||||
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
|
||||
t.Run("Template_Execution_Context", testTemplateExecutionContext)
|
||||
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
|
||||
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
|
||||
t.Run("Missing_Field_Handling", testMissingFieldHandling)
|
||||
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
|
||||
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
|
||||
}
|
||||
|
||||
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
|
||||
// receive wrong data types during execution - reproduces GitHub issue #55
|
||||
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
templateData interface{}
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "correct map data",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "valid-token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "boolean as root context - reproduces issue #55",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: true,
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type bool",
|
||||
},
|
||||
{
|
||||
name: "string as root context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: "just a string",
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type string",
|
||||
},
|
||||
{
|
||||
name: "nested claims access with correct data",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nested claims with wrong structure",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": "not a map",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field email in type",
|
||||
},
|
||||
{
|
||||
name: "complex nested structure",
|
||||
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "token123",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user-id",
|
||||
"groups": "admin,users",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.templateData)
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
if tc.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateParsingValidation ensures templates are parsed correctly
|
||||
func testTemplateParsingValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headerTemplates []TemplatedHeader
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid bearer token template",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple valid templates",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "template with conditional logic",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid template syntax",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Bad-Template", Value: "{{.AccessToken"},
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, header := range tc.headerTemplates {
|
||||
_, err := template.New(header.Name).Parse(header.Value)
|
||||
|
||||
if tc.shouldError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testMiddlewareHeaderTemplating simulates the actual middleware flow
|
||||
func testMiddlewareHeaderTemplating(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
accessToken string
|
||||
idToken string
|
||||
claims map[string]interface{}
|
||||
expectedValues map[string]string
|
||||
}{
|
||||
{
|
||||
name: "authorization header with access token",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
expectedValues: map[string]string{
|
||||
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple headers with claims",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
|
||||
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "token123",
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"groups": "admin,developers",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Email": "user@example.com",
|
||||
"X-User-Groups": "admin,developers",
|
||||
"X-Auth-Token": "token123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex template expressions",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
|
||||
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
|
||||
},
|
||||
accessToken: "access-token",
|
||||
idToken: "id-token",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-12345",
|
||||
"email": "john@example.com",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Info": "user-12345 (john@example.com)",
|
||||
"X-Auth-Header": "Bearer access-token | ID: id-token",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Parse all templates
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
// Create template data
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": tc.accessToken,
|
||||
"IDToken": tc.idToken,
|
||||
"Claims": tc.claims,
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Execute templates and set headers
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(headerName, buf.String())
|
||||
}
|
||||
|
||||
// Verify all expected headers are set correctly
|
||||
for headerName, expectedValue := range tc.expectedValues {
|
||||
actualValue := req.Header.Get(headerName)
|
||||
assert.Equal(t, expectedValue, actualValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testJSONConfigParsing tests that JSON configuration is properly parsed
|
||||
func testJSONConfigParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jsonConfig string
|
||||
expectedError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid JSON configuration",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": "Bearer {{.AccessToken}}"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: false,
|
||||
description: "Properly formatted JSON with string values",
|
||||
},
|
||||
{
|
||||
name: "JSON with boolean value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": true
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Boolean value instead of string template",
|
||||
},
|
||||
{
|
||||
name: "JSON with number value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": 123
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Number value instead of string template",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var config struct {
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
|
||||
|
||||
if tc.expectedError {
|
||||
require.Error(t, err, tc.description)
|
||||
} else {
|
||||
require.NoError(t, err, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateDoubleProcessing tests if template strings are being double-processed
|
||||
func testTemplateDoubleProcessing(t *testing.T) {
|
||||
// Simulate how Traefik passes config to the plugin
|
||||
config := &MockConfig{
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Verify that template strings are still raw (not processed)
|
||||
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
|
||||
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
|
||||
|
||||
// Simulate template parsing during initialization
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" || val == "<no value>" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
"get": func(m interface{}, key string) interface{} {
|
||||
if mapVal, ok := m.(map[string]interface{}); ok {
|
||||
if val, exists := mapVal[key]; exists {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
},
|
||||
}
|
||||
|
||||
for _, header := range config.Headers {
|
||||
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
|
||||
parsedTmpl, err := tmpl.Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = parsedTmpl
|
||||
}
|
||||
|
||||
// Test execution with actual claims
|
||||
claims := map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
// Note: internal_role is missing
|
||||
}
|
||||
|
||||
templateData := map[string]interface{}{
|
||||
"Claims": claims,
|
||||
}
|
||||
|
||||
// Execute templates
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := buf.String()
|
||||
if headerName == "X-User-Email" {
|
||||
assert.Equal(t, "user@example.com", result)
|
||||
} else if headerName == "X-User-Role" {
|
||||
// With missingkey=zero, missing fields return "<no value>"
|
||||
assert.Equal(t, "<no value>", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateExecutionContext tests the specific template data context
|
||||
func testTemplateExecutionContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token",
|
||||
"IDToken": "id-token",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Custom non-standard claims",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"app_metadata": map[string]interface{}{
|
||||
"organization": map[string]interface{}{
|
||||
"name": "acme-corp",
|
||||
},
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: acme-corp, X-Team: platform",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedValue, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
|
||||
func testTemplateIntegrationWithPlugin(t *testing.T) {
|
||||
// Test template integration using mock plugin components
|
||||
|
||||
// Set up test OIDC server
|
||||
var testServerURL string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": testServerURL,
|
||||
"authorization_endpoint": testServerURL + "/auth",
|
||||
"token_endpoint": testServerURL + "/token",
|
||||
"jwks_uri": testServerURL + "/jwks",
|
||||
"userinfo_endpoint": testServerURL + "/userinfo",
|
||||
})
|
||||
case "/jwks":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []interface{}{},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer testServer.Close()
|
||||
testServerURL = testServer.URL
|
||||
|
||||
// Create config with templates that reference potentially missing fields
|
||||
config := &MockConfig{
|
||||
ProviderURL: testServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-characters",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize plugin would be done here
|
||||
ctx := context.Background()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Test would create plugin handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = config
|
||||
}
|
||||
|
||||
// testTemplateSyntaxValidation tests that template syntax is properly validated
|
||||
func testTemplateSyntaxValidation(t *testing.T) {
|
||||
validTemplates := []string{
|
||||
"{{.Claims.email}}",
|
||||
"{{.Claims.internal_role}}",
|
||||
"{{.AccessToken}}",
|
||||
"{{.IdToken}}",
|
||||
"{{.RefreshToken}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range validTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
|
||||
}
|
||||
|
||||
// Test invalid templates
|
||||
invalidTemplates := []struct {
|
||||
template string
|
||||
reason string
|
||||
}{
|
||||
{"{{call .SomeFunc}}", "function calls not allowed"},
|
||||
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
|
||||
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
|
||||
{"{{index .Array 0}}", "index access blocked"},
|
||||
{"{{printf \"%s\" .Data}}", "printf blocked"},
|
||||
}
|
||||
|
||||
for _, tc := range invalidTemplates {
|
||||
err := validateTemplateSecure(tc.template)
|
||||
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
|
||||
}
|
||||
|
||||
// Test safe custom functions
|
||||
safeTemplates := []string{
|
||||
"{{get .Claims \"internal_role\"}}",
|
||||
"{{default \"guest\" .Claims.role}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range safeTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Mock validation function for template security
|
||||
func validateTemplateSecure(templateStr string) error {
|
||||
// List of potentially dangerous template actions
|
||||
dangerousFunctions := []string{
|
||||
"call", "range", "with", "index", "printf", "println", "print",
|
||||
"js", "html", "urlquery", "base64", "exec",
|
||||
}
|
||||
|
||||
for _, dangerous := range dangerousFunctions {
|
||||
if strings.Contains(templateStr, dangerous) {
|
||||
return fmt.Errorf("dangerous template function detected: %s", dangerous)
|
||||
}
|
||||
}
|
||||
|
||||
// Define safe custom functions
|
||||
funcMap := template.FuncMap{
|
||||
"get": func(data map[string]interface{}, key string) interface{} {
|
||||
return data[key]
|
||||
},
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
}
|
||||
|
||||
// Try to parse the template with custom functions to check for syntax errors
|
||||
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// testMissingFieldHandling tests handling of missing fields in templates
|
||||
func testMissingFieldHandling(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "missing claim field",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing nested field",
|
||||
templateText: "{{.Claims.user.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"user": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing entire path",
|
||||
templateText: "{{.Missing.Path.Field}}",
|
||||
data: map[string]interface{}{},
|
||||
expected: "<no value>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testComplexTemplateExpressions tests complex template expressions
|
||||
func testComplexTemplateExpressions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "conditional template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
expected: "Admin User",
|
||||
},
|
||||
{
|
||||
name: "multiple claims concatenation",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
},
|
||||
},
|
||||
expected: "John Doe <john.doe@example.com>",
|
||||
},
|
||||
{
|
||||
name: "array access",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
expected: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
|
||||
func testTraefikConfigurationParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *MockConfig
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid configuration with templated headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Standard configuration should work",
|
||||
},
|
||||
{
|
||||
name: "configuration with multiple headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Multiple headers should work",
|
||||
},
|
||||
{
|
||||
name: "empty headers configuration",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Empty headers should not cause issues",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a simple next handler
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Try to create the middleware would be done here
|
||||
ctx := context.Background()
|
||||
|
||||
// Test would create middleware handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = tc.config
|
||||
|
||||
// For now, we just validate the configuration is well-formed
|
||||
if !tc.expectError {
|
||||
require.NotNil(t, tc.config, tc.description)
|
||||
require.NotEmpty(t, tc.config.ClientID, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -18,5 +18,6 @@ require (
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
)
|
||||
|
||||
@@ -20,8 +20,10 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
|
||||
@@ -1,764 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// OAuth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestOAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
|
||||
// Test authorization request handling logic
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestURL string
|
||||
expectedStatus int
|
||||
checkLocation bool
|
||||
}{
|
||||
{
|
||||
name: "Valid authorization request",
|
||||
requestURL: "/auth/login",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
{
|
||||
name: "With return URL",
|
||||
requestURL: "/auth/login?return=/dashboard",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the test case structure
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.requestURL == "" {
|
||||
t.Error("Request URL should not be empty")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
// In a real implementation, this would test the actual handler
|
||||
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Authorization request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleCallbackRequest", func(t *testing.T) {
|
||||
// Test callback request handling with existing mocks
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid callback with code",
|
||||
queryParams: "code=test-code&state=test-state",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Callback with error",
|
||||
queryParams: "error=access_denied&error_description=User denied access",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing code",
|
||||
queryParams: "state=test-state",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing state",
|
||||
queryParams: "code=test-code",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the callback scenarios
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.queryParams == "" && !test.expectError {
|
||||
t.Error("Query params should not be empty for successful cases")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
|
||||
// Test session manager functionality
|
||||
if sessionManager != nil {
|
||||
t.Logf("Session manager available for test %s", test.name)
|
||||
}
|
||||
|
||||
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Callback request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleLogout", func(t *testing.T) {
|
||||
// Test logout functionality with mock implementations
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
// Test session clearing
|
||||
mockReq := &http.Request{}
|
||||
session, err := sessionManager.GetSession(mockReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set up authenticated session
|
||||
err = session.SetAuthenticated(true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set authentication: %v", err)
|
||||
}
|
||||
session.SetIDToken("test-token")
|
||||
|
||||
// Verify session is authenticated
|
||||
if !session.GetAuthenticated() {
|
||||
t.Error("Session should be authenticated before logout")
|
||||
}
|
||||
|
||||
// Test logout by clearing session
|
||||
// session.Clear() // Method not implemented in SessionData
|
||||
// Additional logout verification would go here
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Logout test completed")
|
||||
t.Log("Logout test completed successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Auth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthentication", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
/*
|
||||
handler := &MockAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*MockSession)
|
||||
expectedStatus int
|
||||
expectNext bool
|
||||
}{
|
||||
{
|
||||
name: "Authenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("valid-token")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectNext: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(false)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("expired-token")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandleRefreshToken", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
mockResponse *MockTokenResponse
|
||||
mockError error
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "Successful refresh",
|
||||
refreshToken: "valid-refresh-token",
|
||||
mockResponse: &MockTokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
IDToken: "new-id-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
},
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Failed refresh",
|
||||
refreshToken: "invalid-refresh-token",
|
||||
mockError: errors.New("invalid_grant"),
|
||||
expectSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "Empty refresh token",
|
||||
refreshToken: "",
|
||||
expectSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestErrorHandler(t *testing.T) {
|
||||
t.Run("HandleHTTPErrors", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorCode int
|
||||
errorMessage string
|
||||
isAjax bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Authentication required",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "403 Forbidden",
|
||||
errorCode: http.StatusForbidden,
|
||||
errorMessage: "Access denied",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: "Access denied",
|
||||
},
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
errorCode: http.StatusInternalServerError,
|
||||
errorMessage: "Internal server error",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal server error",
|
||||
},
|
||||
{
|
||||
name: "Ajax 401",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Token expired",
|
||||
isAjax: true,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RecoverFromPanic", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
panicValue interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String panic",
|
||||
panicValue: "something went wrong",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Error panic",
|
||||
panicValue: errors.New("critical error"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Nil panic",
|
||||
panicValue: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Azure OAuth Callback Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAzureOAuthCallback(t *testing.T) {
|
||||
t.Run("AzureSpecificClaims", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
azureClaims := map[string]interface{}{
|
||||
"oid": "object-id",
|
||||
"tid": "tenant-id",
|
||||
"preferred_username": "user@example.com",
|
||||
"name": "Test User",
|
||||
"email": "user@example.com",
|
||||
"groups": []string{"group1", "group2"},
|
||||
}
|
||||
|
||||
// Test would go here when properly implemented
|
||||
_ = azureClaims
|
||||
})
|
||||
|
||||
t.Run("AzureTokenValidation", func(t *testing.T) {
|
||||
// Test with mock validator types
|
||||
/*
|
||||
validator := &MockAzureTokenValidator{
|
||||
tenantID: "test-tenant",
|
||||
clientID: "test-client",
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
claims map[string]interface{}
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Azure token",
|
||||
token: "valid-azure-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong tenant",
|
||||
token: "wrong-tenant-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "wrong-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "Wrong audience",
|
||||
token: "wrong-audience-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "wrong-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConcurrentHandlers(t *testing.T) {
|
||||
t.Run("ConcurrentCallbacks", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int32(0)
|
||||
errorCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = successCount
|
||||
_ = errorCount
|
||||
})
|
||||
|
||||
t.Run("ConcurrentLogouts", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
logoutCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = logoutCount
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Implementations
|
||||
// ============================================================================
|
||||
|
||||
type MockSessionManager struct {
|
||||
sessions map[string]*MockSession
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMockSessionManager() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
sessions: make(map[string]*MockSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessionID := "test-session"
|
||||
if session, exists := m.sessions[sessionID]; exists {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
session := &MockSession{
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
m.sessions[sessionID] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
values map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAuthenticated(auth bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["authenticated"] = auth
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAuthenticated() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
auth, ok := s.values["authenticated"].(bool)
|
||||
return ok && auth
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIDToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["id_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIDToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["id_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAccessToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["access_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAccessToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["access_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetRefreshToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["refresh_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetRefreshToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["refresh_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetState(state string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["state"] = state
|
||||
}
|
||||
|
||||
func (s *MockSession) GetState() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
state, _ := s.values["state"].(string)
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *MockSession) SetClaims(claims map[string]interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["claims"] = claims
|
||||
}
|
||||
|
||||
func (s *MockSession) GetClaims() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
claims, _ := s.values["claims"].(map[string]interface{})
|
||||
return claims
|
||||
}
|
||||
|
||||
// Additional SessionData interface methods to match real interface
|
||||
func (s *MockSession) GetCSRF() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
csrf, _ := s.values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) GetNonce() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
nonce, _ := s.values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) GetCodeVerifier() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
verifier, _ := s.values["code_verifier"].(string)
|
||||
return verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIncomingPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
path, _ := s.values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
func (s *MockSession) SetEmail(email string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["email"] = email
|
||||
}
|
||||
|
||||
func (s *MockSession) GetEmail() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
email, _ := s.values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCSRF(csrf string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["csrf"] = csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) SetNonce(nonce string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["nonce"] = nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCodeVerifier(verifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["code_verifier"] = verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIncomingPath(path string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["incoming_path"] = path
|
||||
}
|
||||
|
||||
func (s *MockSession) ResetRedirectCount() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["redirect_count"] = 0
|
||||
}
|
||||
|
||||
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values = make(map[string]interface{})
|
||||
}
|
||||
|
||||
func (s *MockSession) returnToPoolSafely() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
type MockTokenValidator struct {
|
||||
valid bool
|
||||
}
|
||||
|
||||
func (v *MockTokenValidator) Validate(token string) bool {
|
||||
if token == "expired-token" {
|
||||
return false
|
||||
}
|
||||
return v.valid
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Handler Type Definitions (for testing)
|
||||
// ============================================================================
|
||||
|
||||
// These mock handlers are simplified versions for testing purposes
|
||||
// They don't match the actual handler implementations
|
||||
|
||||
type MockAuthHandler struct{}
|
||||
|
||||
type MockErrorHandler struct{}
|
||||
|
||||
type MockAzureTokenValidator struct {
|
||||
tenantID string
|
||||
clientID string
|
||||
}
|
||||
|
||||
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
|
||||
// Validate tenant ID
|
||||
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate audience
|
||||
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Types and Mock Logger
|
||||
// ============================================================================
|
||||
|
||||
type MockLogger struct{}
|
||||
|
||||
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Error(msg string) {}
|
||||
|
||||
type MockTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
@@ -1,313 +0,0 @@
|
||||
// Package handlers provides HTTP request handlers for the OIDC middleware.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth callback requests
|
||||
type OAuthHandler struct {
|
||||
logger Logger
|
||||
sessionManager SessionManager
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
}
|
||||
|
||||
// SessionManager interface for session operations
|
||||
type SessionManager interface {
|
||||
GetSession(req *http.Request) (SessionData, error)
|
||||
}
|
||||
|
||||
// SessionData interface for session data operations
|
||||
type SessionData interface {
|
||||
GetCSRF() string
|
||||
GetNonce() string
|
||||
GetCodeVerifier() string
|
||||
GetIncomingPath() string
|
||||
GetAuthenticated() bool
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
GetIDToken() string
|
||||
GetEmail() string
|
||||
SetAuthenticated(bool) error
|
||||
SetEmail(string)
|
||||
SetIDToken(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetCSRF(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetIncomingPath(string)
|
||||
ResetRedirectCount()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
returnToPoolSafely()
|
||||
}
|
||||
|
||||
// TokenExchanger interface for token operations
|
||||
type TokenExchanger interface {
|
||||
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
}
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from token exchange
|
||||
type TokenResponse struct {
|
||||
IDToken string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleCallback handles OAuth callback requests
|
||||
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := h.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Session error during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Debug logging for cookie configuration
|
||||
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
|
||||
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
|
||||
|
||||
// Log all cookies in the request for debugging
|
||||
cookies := req.Cookies()
|
||||
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_") {
|
||||
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
|
||||
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error")
|
||||
}
|
||||
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
h.logger.Error("No state in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log the state parameter received
|
||||
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
||||
session.GetAuthenticated(), req.URL.String())
|
||||
|
||||
// Enhanced debugging for missing CSRF token
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
// Log cookie names only, not values (avoid logging sensitive session data)
|
||||
cookieNames := make([]string, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
h.logger.Debugf("Available cookies (names only): %v", cookieNames)
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
|
||||
}
|
||||
|
||||
// Log session state for debugging
|
||||
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
|
||||
session.GetAuthenticated(), session.GetAccessToken() != "")
|
||||
|
||||
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log successful CSRF token retrieval
|
||||
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
|
||||
|
||||
if state != csrfToken {
|
||||
h.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
h.logger.Error("No code in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to extract claims during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
h.logger.Error("Nonce claim missing in id_token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
h.logger.Error("Nonce not found in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
h.logger.Error("Nonce claim does not match session nonce during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.ResetRedirectCount()
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session after callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// URLHelper provides utility methods for URL operations
|
||||
type URLHelper struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewURLHelper creates a new URL helper
|
||||
func NewURLHelper(logger Logger) *URLHelper {
|
||||
return &URLHelper{logger: logger}
|
||||
}
|
||||
|
||||
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
|
||||
// It compares the request path against configured excluded URL prefixes.
|
||||
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
|
||||
for excludedURL := range excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DetermineScheme determines the URL scheme for building redirect URLs.
|
||||
// It checks X-Forwarded-Proto header first, then TLS presence.
|
||||
func (h *URLHelper) DetermineScheme(req *http.Request) string {
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// DetermineHost determines the host for building redirect URLs.
|
||||
// It checks X-Forwarded-Host header first, then falls back to req.Host.
|
||||
func (h *URLHelper) DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
@@ -1,899 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test mocks - implementing interfaces defined in oauth_handler.go
|
||||
type mockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Error(msg string) {
|
||||
l.errorMessages = append(l.errorMessages, msg)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
sessionToReturn SessionData
|
||||
errorToReturn error
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
|
||||
return m.sessionToReturn, m.errorToReturn
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
saveError error
|
||||
setAuthError error
|
||||
}
|
||||
|
||||
func (s *mockSessionData) GetCSRF() string { return s.csrf }
|
||||
func (s *mockSessionData) GetNonce() string { return s.nonce }
|
||||
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
|
||||
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
|
||||
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
|
||||
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
|
||||
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
|
||||
func (s *mockSessionData) GetIDToken() string { return s.idToken }
|
||||
func (s *mockSessionData) GetEmail() string { return s.email }
|
||||
|
||||
func (s *mockSessionData) SetAuthenticated(auth bool) error {
|
||||
s.authenticated = auth
|
||||
return s.setAuthError
|
||||
}
|
||||
|
||||
func (s *mockSessionData) SetEmail(email string) { s.email = email }
|
||||
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
|
||||
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
|
||||
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
|
||||
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
|
||||
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
|
||||
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
|
||||
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
|
||||
func (s *mockSessionData) ResetRedirectCount() {}
|
||||
func (s *mockSessionData) returnToPoolSafely() {}
|
||||
|
||||
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return s.saveError
|
||||
}
|
||||
|
||||
type mockTokenExchanger struct {
|
||||
response *TokenResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
return e.response, e.err
|
||||
}
|
||||
|
||||
type mockTokenVerifier struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *mockTokenVerifier) VerifyToken(token string) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
// TestOAuthHandler_NewOAuthHandler tests the constructor
|
||||
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.redirURLPath != "/callback" {
|
||||
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
|
||||
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Session error") {
|
||||
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
|
||||
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Authentication error from provider") {
|
||||
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
|
||||
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "State parameter missing") {
|
||||
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
|
||||
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: ""} // Empty CSRF
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF token missing") {
|
||||
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
|
||||
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "different-token"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF mismatch") {
|
||||
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
|
||||
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "No authorization code received") {
|
||||
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
|
||||
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not exchange code for token") {
|
||||
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
|
||||
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not verify ID token") {
|
||||
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
|
||||
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, errors.New("claims extraction failed")
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not extract claims") {
|
||||
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
// Claims without nonce
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in session") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
|
||||
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce mismatch") {
|
||||
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
|
||||
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
|
||||
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return false } // Disallow all domains
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
|
||||
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
saveError: errors.New("save failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to save session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
|
||||
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
setAuthError: errors.New("set auth failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to update session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
|
||||
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/dashboard",
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{
|
||||
IDToken: "valid-id-token",
|
||||
AccessToken: "valid-access-token",
|
||||
RefreshToken: "valid-refresh-token",
|
||||
}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if errorSent {
|
||||
t.Error("Unexpected error response sent")
|
||||
}
|
||||
|
||||
// Check redirect
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
|
||||
}
|
||||
|
||||
// Verify session data was set correctly
|
||||
if session.email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
|
||||
}
|
||||
|
||||
if session.idToken != "valid-id-token" {
|
||||
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
|
||||
}
|
||||
|
||||
if session.accessToken != "valid-access-token" {
|
||||
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
|
||||
}
|
||||
|
||||
if session.refreshToken != "valid-refresh-token" {
|
||||
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
|
||||
}
|
||||
|
||||
if !session.authenticated {
|
||||
t.Error("Expected session to be authenticated")
|
||||
}
|
||||
|
||||
// Check that temporary fields are cleared
|
||||
if session.csrf != "" {
|
||||
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
|
||||
}
|
||||
|
||||
if session.nonce != "" {
|
||||
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.codeVerifier != "" {
|
||||
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
|
||||
}
|
||||
|
||||
if session.incomingPath != "" {
|
||||
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
|
||||
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "", // No incoming path, should default to "/"
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Check redirect to default path
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
|
||||
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/callback", // Same as redirect URL path
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Should redirect to default path when incoming path is same as callback path
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
@@ -1,454 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestURLHelper_NewURLHelper tests the URLHelper constructor
|
||||
func TestURLHelper_NewURLHelper(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
if helper == nil {
|
||||
t.Fatal("Expected URLHelper to be created, got nil")
|
||||
}
|
||||
|
||||
if helper.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
|
||||
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentURL string
|
||||
excludedURLs map[string]struct{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Exact match",
|
||||
currentURL: "/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Prefix match",
|
||||
currentURL: "/health/status",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - first match",
|
||||
currentURL: "/api/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - second match",
|
||||
currentURL: "/health/check",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded URLs",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Root path exclusion",
|
||||
currentURL: "/anything",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive matching",
|
||||
currentURL: "/API/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Partial substring but not prefix",
|
||||
currentURL: "/user/api/test",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty current URL",
|
||||
currentURL: "",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "URL with query parameters",
|
||||
currentURL: "/health?status=ok",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
|
||||
// Verify debug logging for excluded URLs
|
||||
if result && len(logger.debugMessages) > 0 {
|
||||
// Should have logged a debug message for excluded URL
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if msg == "URL is excluded - got %s / excluded hit: %s" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected debug message for excluded URL")
|
||||
}
|
||||
}
|
||||
|
||||
// Reset logger messages for next test
|
||||
logger.debugMessages = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineScheme tests scheme determination
|
||||
func TestURLHelper_DetermineScheme(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - https",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - http",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "TLS connection without X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "No TLS and no X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto takes precedence over TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Proto falls back to TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineScheme(req)
|
||||
if result != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineHost tests host determination
|
||||
func TestURLHelper_DetermineHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Host header present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "No X-Forwarded-Host, use req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "direct.example.com"
|
||||
return req
|
||||
},
|
||||
expectedHost: "direct.example.com",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Host falls back to req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "fallback.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "")
|
||||
return req
|
||||
},
|
||||
expectedHost: "fallback.example.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com:8080"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com:443",
|
||||
},
|
||||
{
|
||||
name: "req.Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
|
||||
req.Host = "example.com:8080"
|
||||
return req
|
||||
},
|
||||
expectedHost: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "Multiple X-Forwarded-Host values (first one used)",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "first.example.com, second.example.com", // Header value as-is
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineHost(req)
|
||||
if result != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
|
||||
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "Both headers present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, TLS connection",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "secure.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, no TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
|
||||
req.Host = "plain.example.com"
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "plain.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only scheme header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "mixed.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "mixed.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only host header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "external.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
|
||||
scheme := helper.DetermineScheme(req)
|
||||
host := helper.DetermineHost(req)
|
||||
|
||||
if scheme != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
|
||||
}
|
||||
|
||||
if host != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
|
||||
}
|
||||
|
||||
// Test that we can build a complete URL
|
||||
fullURL := scheme + "://" + host + "/callback"
|
||||
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
|
||||
if fullURL != expectedURL {
|
||||
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure the helper methods are performant
|
||||
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
excludedURLs := map[string]struct{}{
|
||||
"/health": {},
|
||||
"/metrics": {},
|
||||
"/status": {},
|
||||
"/api/v1": {},
|
||||
"/api/v2": {},
|
||||
"/static": {},
|
||||
"/assets": {},
|
||||
"/favicon": {},
|
||||
"/robots": {},
|
||||
"/sitemap": {},
|
||||
}
|
||||
|
||||
testURL := "/api/users"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineExcludedURL(testURL, excludedURLs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineScheme(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineHost(req)
|
||||
}
|
||||
}
|
||||
+4
-2
@@ -13,6 +13,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
|
||||
@@ -349,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
host := utils.DetermineHost(req)
|
||||
scheme := utils.DetermineScheme(req, t.forceHTTPS)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+2
-2
@@ -76,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+23
-22
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
|
||||
@@ -158,6 +158,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
@@ -181,6 +182,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
@@ -203,6 +205,7 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
// #nosec G115 -- MaxFailures is a small config value that fits in int32
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
|
||||
+2
@@ -217,6 +217,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
@@ -241,6 +242,7 @@ func (hc *HealthChecker) recordFailure() {
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
|
||||
@@ -150,6 +150,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
// #nosec G115 -- threshold config values are small integers that fit in int32
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// Package errors provides unified error handling for OIDC operations
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorCode represents specific error types
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// Authentication errors
|
||||
ErrCodeAuthenticationFailed ErrorCode = "AUTH_FAILED"
|
||||
ErrCodeTokenExpired ErrorCode = "TOKEN_EXPIRED"
|
||||
ErrCodeTokenInvalid ErrorCode = "TOKEN_INVALID"
|
||||
ErrCodeSessionExpired ErrorCode = "SESSION_EXPIRED"
|
||||
ErrCodeCSRFMismatch ErrorCode = "CSRF_MISMATCH"
|
||||
ErrCodeNonceMismatch ErrorCode = "NONCE_MISMATCH"
|
||||
|
||||
// Configuration errors
|
||||
ErrCodeConfigInvalid ErrorCode = "CONFIG_INVALID"
|
||||
ErrCodeProviderUnreachable ErrorCode = "PROVIDER_UNREACHABLE"
|
||||
ErrCodeMetadataFailed ErrorCode = "METADATA_FAILED"
|
||||
|
||||
// Network errors
|
||||
ErrCodeNetworkTimeout ErrorCode = "NETWORK_TIMEOUT"
|
||||
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
|
||||
ErrCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
|
||||
|
||||
// Validation errors
|
||||
ErrCodeValidationFailed ErrorCode = "VALIDATION_FAILED"
|
||||
ErrCodeDomainNotAllowed ErrorCode = "DOMAIN_NOT_ALLOWED"
|
||||
ErrCodeUserNotAllowed ErrorCode = "USER_NOT_ALLOWED"
|
||||
ErrCodeRoleNotAllowed ErrorCode = "ROLE_NOT_ALLOWED"
|
||||
)
|
||||
|
||||
// OIDCError represents a structured error with context
|
||||
type OIDCError struct {
|
||||
Code ErrorCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
HTTPStatus int `json:"http_status"`
|
||||
Internal error `json:"-"` // Internal error, not exposed
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *OIDCError) Error() string {
|
||||
if e.Details != "" {
|
||||
return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the internal error for error wrapping
|
||||
func (e *OIDCError) Unwrap() error {
|
||||
return e.Internal
|
||||
}
|
||||
|
||||
// IsRetryable indicates if the error is temporary and can be retried
|
||||
func (e *OIDCError) IsRetryable() bool {
|
||||
return e.Code == ErrCodeNetworkTimeout ||
|
||||
e.Code == ErrCodeServiceUnavailable ||
|
||||
e.Code == ErrCodeProviderUnreachable
|
||||
}
|
||||
|
||||
// IsAuthenticationError indicates if this is an authentication-related error
|
||||
func (e *OIDCError) IsAuthenticationError() bool {
|
||||
return e.Code == ErrCodeAuthenticationFailed ||
|
||||
e.Code == ErrCodeTokenExpired ||
|
||||
e.Code == ErrCodeTokenInvalid ||
|
||||
e.Code == ErrCodeSessionExpired ||
|
||||
e.Code == ErrCodeCSRFMismatch ||
|
||||
e.Code == ErrCodeNonceMismatch
|
||||
}
|
||||
|
||||
// IsAuthorizationError indicates if this is an authorization-related error
|
||||
func (e *OIDCError) IsAuthorizationError() bool {
|
||||
return e.Code == ErrCodeDomainNotAllowed ||
|
||||
e.Code == ErrCodeUserNotAllowed ||
|
||||
e.Code == ErrCodeRoleNotAllowed
|
||||
}
|
||||
|
||||
// ToJSON converts the error to a JSON response
|
||||
func (e *OIDCError) ToJSON() map[string]any {
|
||||
result := map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": string(e.Code),
|
||||
"message": e.Message,
|
||||
},
|
||||
}
|
||||
|
||||
if e.Details != "" {
|
||||
errorMap, _ := result["error"].(map[string]any) // Safe to ignore: type assertion from known type
|
||||
errorMap["details"] = e.Details
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Error constructors for common scenarios
|
||||
|
||||
// NewAuthenticationError creates an authentication-related error
|
||||
func NewAuthenticationError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
status := http.StatusUnauthorized
|
||||
if code == ErrCodeSessionExpired {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: status,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthorizationError creates an authorization-related error
|
||||
func NewAuthorizationError(code ErrorCode, message string, details string) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
}
|
||||
}
|
||||
|
||||
// NewConfigurationError creates a configuration-related error
|
||||
func NewConfigurationError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewNetworkError creates a network-related error
|
||||
func NewNetworkError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
status := http.StatusServiceUnavailable
|
||||
if code == ErrCodeRateLimited {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: status,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewValidationError creates a validation-related error
|
||||
func NewValidationError(code ErrorCode, message string, details string) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience functions for common error patterns
|
||||
|
||||
// WrapAuthenticationError wraps an existing error as an authentication error
|
||||
func WrapAuthenticationError(err error, message string) *OIDCError {
|
||||
return NewAuthenticationError(ErrCodeAuthenticationFailed, message, err)
|
||||
}
|
||||
|
||||
// WrapTokenError wraps a token-related error
|
||||
func WrapTokenError(err error, tokenType string) *OIDCError {
|
||||
message := fmt.Sprintf("Token validation failed: %s", tokenType)
|
||||
return NewAuthenticationError(ErrCodeTokenInvalid, message, err)
|
||||
}
|
||||
|
||||
// WrapProviderError wraps a provider communication error
|
||||
func WrapProviderError(err error, providerURL string) *OIDCError {
|
||||
message := fmt.Sprintf("Provider communication failed: %s", providerURL)
|
||||
return NewNetworkError(ErrCodeProviderUnreachable, message, err)
|
||||
}
|
||||
|
||||
// IsOIDCError checks if an error is an OIDCError
|
||||
func IsOIDCError(err error) (*OIDCError, bool) {
|
||||
oidcErr, ok := err.(*OIDCError)
|
||||
return oidcErr, ok
|
||||
}
|
||||
|
||||
// GetHTTPStatus extracts HTTP status from error, defaulting to 500
|
||||
func GetHTTPStatus(err error) int {
|
||||
if oidcErr, ok := IsOIDCError(err); ok {
|
||||
return oidcErr.HTTPStatus
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// FormatUserMessage creates a user-friendly error message
|
||||
func FormatUserMessage(err error) string {
|
||||
if oidcErr, ok := IsOIDCError(err); ok {
|
||||
switch oidcErr.Code {
|
||||
case ErrCodeDomainNotAllowed:
|
||||
return "Your email domain is not authorized for this application"
|
||||
case ErrCodeUserNotAllowed:
|
||||
return "Your account is not authorized for this application"
|
||||
case ErrCodeRoleNotAllowed:
|
||||
return "You do not have the required permissions for this application"
|
||||
case ErrCodeSessionExpired:
|
||||
return "Your session has expired. Please log in again"
|
||||
case ErrCodeTokenExpired:
|
||||
return "Your authentication has expired. Please log in again"
|
||||
case ErrCodeProviderUnreachable:
|
||||
return "Authentication service is temporarily unavailable. Please try again later"
|
||||
case ErrCodeRateLimited:
|
||||
return "Too many requests. Please wait a moment and try again"
|
||||
default:
|
||||
return "Authentication failed. Please try again"
|
||||
}
|
||||
}
|
||||
return "An unexpected error occurred. Please try again"
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOIDCError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcErr *OIDCError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Error with details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Details: "JWT signature invalid",
|
||||
},
|
||||
expected: "TOKEN_INVALID: Token validation failed (JWT signature invalid)",
|
||||
},
|
||||
{
|
||||
name: "Error without details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeAuthenticationFailed,
|
||||
Message: "Authentication failed",
|
||||
},
|
||||
expected: "AUTH_FAILED: Authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.oidcErr.Error()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_Unwrap(t *testing.T) {
|
||||
internalErr := errors.New("internal error")
|
||||
oidcErr := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Internal: internalErr,
|
||||
}
|
||||
|
||||
unwrapped := oidcErr.Unwrap()
|
||||
if unwrapped != internalErr {
|
||||
t.Errorf("Expected internal error, got %v", unwrapped)
|
||||
}
|
||||
|
||||
// Test with nil internal error
|
||||
oidcErrNoInternal := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
}
|
||||
|
||||
unwrappedNil := oidcErrNoInternal.Unwrap()
|
||||
if unwrappedNil != nil {
|
||||
t.Errorf("Expected nil, got %v", unwrappedNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Network timeout", ErrCodeNetworkTimeout, true},
|
||||
{"Service unavailable", ErrCodeServiceUnavailable, true},
|
||||
{"Provider unreachable", ErrCodeProviderUnreachable, true},
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, false},
|
||||
{"Token invalid", ErrCodeTokenInvalid, false},
|
||||
{"Rate limited", ErrCodeRateLimited, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsRetryable()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsAuthenticationError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, true},
|
||||
{"Token expired", ErrCodeTokenExpired, true},
|
||||
{"Token invalid", ErrCodeTokenInvalid, true},
|
||||
{"Session expired", ErrCodeSessionExpired, true},
|
||||
{"CSRF mismatch", ErrCodeCSRFMismatch, true},
|
||||
{"Nonce mismatch", ErrCodeNonceMismatch, true},
|
||||
{"Config invalid", ErrCodeConfigInvalid, false},
|
||||
{"Domain not allowed", ErrCodeDomainNotAllowed, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsAuthenticationError()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsAuthorizationError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Domain not allowed", ErrCodeDomainNotAllowed, true},
|
||||
{"User not allowed", ErrCodeUserNotAllowed, true},
|
||||
{"Role not allowed", ErrCodeRoleNotAllowed, true},
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, false},
|
||||
{"Token expired", ErrCodeTokenExpired, false},
|
||||
{"Config invalid", ErrCodeConfigInvalid, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsAuthorizationError()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_ToJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcErr *OIDCError
|
||||
expected map[string]any
|
||||
}{
|
||||
{
|
||||
name: "Error with details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Details: "JWT signature invalid",
|
||||
},
|
||||
expected: map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "TOKEN_INVALID",
|
||||
"message": "Token validation failed",
|
||||
"details": "JWT signature invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error without details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeAuthenticationFailed,
|
||||
Message: "Authentication failed",
|
||||
},
|
||||
expected: map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "AUTH_FAILED",
|
||||
"message": "Authentication failed",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.oidcErr.ToJSON()
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("Expected %+v, got %+v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthenticationError(t *testing.T) {
|
||||
internalErr := errors.New("internal error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
message string
|
||||
internal error
|
||||
expectedHTTP int
|
||||
}{
|
||||
{
|
||||
name: "Regular auth error",
|
||||
code: ErrCodeAuthenticationFailed,
|
||||
message: "Auth failed",
|
||||
internal: internalErr,
|
||||
expectedHTTP: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Session expired error",
|
||||
code: ErrCodeSessionExpired,
|
||||
message: "Session expired",
|
||||
internal: internalErr,
|
||||
expectedHTTP: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewAuthenticationError(tt.code, tt.message, tt.internal)
|
||||
|
||||
if err.Code != tt.code {
|
||||
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
|
||||
}
|
||||
if err.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, err.Message)
|
||||
}
|
||||
if err.Internal != tt.internal {
|
||||
t.Errorf("Expected internal error %v, got %v", tt.internal, err.Internal)
|
||||
}
|
||||
if err.HTTPStatus != tt.expectedHTTP {
|
||||
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthorizationError(t *testing.T) {
|
||||
err := NewAuthorizationError(ErrCodeDomainNotAllowed, "Domain not allowed", "example.com not in whitelist")
|
||||
|
||||
if err.Code != ErrCodeDomainNotAllowed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeDomainNotAllowed, err.Code)
|
||||
}
|
||||
if err.Message != "Domain not allowed" {
|
||||
t.Errorf("Expected message 'Domain not allowed', got '%s'", err.Message)
|
||||
}
|
||||
if err.Details != "example.com not in whitelist" {
|
||||
t.Errorf("Expected details 'example.com not in whitelist', got '%s'", err.Details)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusForbidden {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusForbidden, err.HTTPStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConfigurationError(t *testing.T) {
|
||||
internalErr := errors.New("config parse error")
|
||||
err := NewConfigurationError(ErrCodeConfigInvalid, "Invalid config", internalErr)
|
||||
|
||||
if err.Code != ErrCodeConfigInvalid {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeConfigInvalid, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusInternalServerError {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusInternalServerError, err.HTTPStatus)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNetworkError(t *testing.T) {
|
||||
internalErr := errors.New("network error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expectedHTTP int
|
||||
}{
|
||||
{
|
||||
name: "Rate limited",
|
||||
code: ErrCodeRateLimited,
|
||||
expectedHTTP: http.StatusTooManyRequests,
|
||||
},
|
||||
{
|
||||
name: "Service unavailable",
|
||||
code: ErrCodeServiceUnavailable,
|
||||
expectedHTTP: http.StatusServiceUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewNetworkError(tt.code, "Network error", internalErr)
|
||||
|
||||
if err.Code != tt.code {
|
||||
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != tt.expectedHTTP {
|
||||
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidationError(t *testing.T) {
|
||||
err := NewValidationError(ErrCodeValidationFailed, "Validation failed", "field 'email' is required")
|
||||
|
||||
if err.Code != ErrCodeValidationFailed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeValidationFailed, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusBadRequest {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusBadRequest, err.HTTPStatus)
|
||||
}
|
||||
if err.Details != "field 'email' is required" {
|
||||
t.Errorf("Expected details 'field 'email' is required', got '%s'", err.Details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAuthenticationError(t *testing.T) {
|
||||
internalErr := errors.New("original error")
|
||||
err := WrapAuthenticationError(internalErr, "Custom auth message")
|
||||
|
||||
if err.Code != ErrCodeAuthenticationFailed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeAuthenticationFailed, err.Code)
|
||||
}
|
||||
if err.Message != "Custom auth message" {
|
||||
t.Errorf("Expected message 'Custom auth message', got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapTokenError(t *testing.T) {
|
||||
internalErr := errors.New("token error")
|
||||
err := WrapTokenError(internalErr, "ID token")
|
||||
|
||||
if err.Code != ErrCodeTokenInvalid {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeTokenInvalid, err.Code)
|
||||
}
|
||||
if err.Message != "Token validation failed: ID token" {
|
||||
t.Errorf("Expected message 'Token validation failed: ID token', got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapProviderError(t *testing.T) {
|
||||
internalErr := errors.New("provider error")
|
||||
err := WrapProviderError(internalErr, "https://provider.example.com")
|
||||
|
||||
if err.Code != ErrCodeProviderUnreachable {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeProviderUnreachable, err.Code)
|
||||
}
|
||||
if err.Message != "Provider communication failed: https://provider.example.com" {
|
||||
t.Errorf("Expected specific message, got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOIDCError(t *testing.T) {
|
||||
// Test with OIDCError
|
||||
oidcErr := &OIDCError{Code: ErrCodeTokenInvalid, Message: "test"}
|
||||
result, ok := IsOIDCError(oidcErr)
|
||||
if !ok {
|
||||
t.Error("Expected IsOIDCError to return true for OIDCError")
|
||||
}
|
||||
if result != oidcErr {
|
||||
t.Error("Expected to get the same OIDCError back")
|
||||
}
|
||||
|
||||
// Test with regular error
|
||||
regularErr := errors.New("regular error")
|
||||
result, ok = IsOIDCError(regularErr)
|
||||
if ok {
|
||||
t.Error("Expected IsOIDCError to return false for regular error")
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("Expected nil result for regular error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHTTPStatus(t *testing.T) {
|
||||
// Test with OIDCError
|
||||
oidcErr := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
}
|
||||
status := GetHTTPStatus(oidcErr)
|
||||
if status != http.StatusUnauthorized {
|
||||
t.Errorf("Expected %d, got %d", http.StatusUnauthorized, status)
|
||||
}
|
||||
|
||||
// Test with regular error
|
||||
regularErr := errors.New("regular error")
|
||||
status = GetHTTPStatus(regularErr)
|
||||
if status != http.StatusInternalServerError {
|
||||
t.Errorf("Expected %d, got %d", http.StatusInternalServerError, status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatUserMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Domain not allowed",
|
||||
err: &OIDCError{Code: ErrCodeDomainNotAllowed},
|
||||
expected: "Your email domain is not authorized for this application",
|
||||
},
|
||||
{
|
||||
name: "User not allowed",
|
||||
err: &OIDCError{Code: ErrCodeUserNotAllowed},
|
||||
expected: "Your account is not authorized for this application",
|
||||
},
|
||||
{
|
||||
name: "Role not allowed",
|
||||
err: &OIDCError{Code: ErrCodeRoleNotAllowed},
|
||||
expected: "You do not have the required permissions for this application",
|
||||
},
|
||||
{
|
||||
name: "Session expired",
|
||||
err: &OIDCError{Code: ErrCodeSessionExpired},
|
||||
expected: "Your session has expired. Please log in again",
|
||||
},
|
||||
{
|
||||
name: "Token expired",
|
||||
err: &OIDCError{Code: ErrCodeTokenExpired},
|
||||
expected: "Your authentication has expired. Please log in again",
|
||||
},
|
||||
{
|
||||
name: "Provider unreachable",
|
||||
err: &OIDCError{Code: ErrCodeProviderUnreachable},
|
||||
expected: "Authentication service is temporarily unavailable. Please try again later",
|
||||
},
|
||||
{
|
||||
name: "Rate limited",
|
||||
err: &OIDCError{Code: ErrCodeRateLimited},
|
||||
expected: "Too many requests. Please wait a moment and try again",
|
||||
},
|
||||
{
|
||||
name: "Unknown OIDC error",
|
||||
err: &OIDCError{Code: ErrCodeConfigInvalid},
|
||||
expected: "Authentication failed. Please try again",
|
||||
},
|
||||
{
|
||||
name: "Regular error",
|
||||
err: errors.New("regular error"),
|
||||
expected: "An unexpected error occurred. Please try again",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FormatUserMessage(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
// Test that all error codes are defined correctly
|
||||
codes := []ErrorCode{
|
||||
ErrCodeAuthenticationFailed,
|
||||
ErrCodeTokenExpired,
|
||||
ErrCodeTokenInvalid,
|
||||
ErrCodeSessionExpired,
|
||||
ErrCodeCSRFMismatch,
|
||||
ErrCodeNonceMismatch,
|
||||
ErrCodeConfigInvalid,
|
||||
ErrCodeProviderUnreachable,
|
||||
ErrCodeMetadataFailed,
|
||||
ErrCodeNetworkTimeout,
|
||||
ErrCodeRateLimited,
|
||||
ErrCodeServiceUnavailable,
|
||||
ErrCodeValidationFailed,
|
||||
ErrCodeDomainNotAllowed,
|
||||
ErrCodeUserNotAllowed,
|
||||
ErrCodeRoleNotAllowed,
|
||||
}
|
||||
|
||||
for _, code := range codes {
|
||||
if string(code) == "" {
|
||||
t.Errorf("Error code %v is empty", code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorConstructorCompleteness(t *testing.T) {
|
||||
// Test each constructor function to ensure they set all required fields
|
||||
internalErr := errors.New("test error")
|
||||
|
||||
// Test NewAuthenticationError
|
||||
authErr := NewAuthenticationError(ErrCodeAuthenticationFailed, "auth message", internalErr)
|
||||
if authErr.Code == "" || authErr.Message == "" || authErr.HTTPStatus == 0 {
|
||||
t.Error("NewAuthenticationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewAuthorizationError
|
||||
authzErr := NewAuthorizationError(ErrCodeDomainNotAllowed, "authz message", "details")
|
||||
if authzErr.Code == "" || authzErr.Message == "" || authzErr.HTTPStatus == 0 {
|
||||
t.Error("NewAuthorizationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewConfigurationError
|
||||
configErr := NewConfigurationError(ErrCodeConfigInvalid, "config message", internalErr)
|
||||
if configErr.Code == "" || configErr.Message == "" || configErr.HTTPStatus == 0 {
|
||||
t.Error("NewConfigurationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewNetworkError
|
||||
netErr := NewNetworkError(ErrCodeNetworkTimeout, "network message", internalErr)
|
||||
if netErr.Code == "" || netErr.Message == "" || netErr.HTTPStatus == 0 {
|
||||
t.Error("NewNetworkError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewValidationError
|
||||
validErr := NewValidationError(ErrCodeValidationFailed, "validation message", "details")
|
||||
if validErr.Code == "" || validErr.Message == "" || validErr.HTTPStatus == 0 {
|
||||
t.Error("NewValidationError did not set all required fields")
|
||||
}
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
// Package handlers provides authentication flow management
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuthFlowHandler manages the complete OIDC authentication flow
|
||||
type AuthFlowHandler struct {
|
||||
sessionHandler *SessionHandler
|
||||
tokenHandler TokenHandler
|
||||
logger Logger
|
||||
excludedURLs map[string]struct{}
|
||||
initComplete chan struct{}
|
||||
issuerURL string
|
||||
}
|
||||
|
||||
// TokenHandler interface for token operations
|
||||
type TokenHandler interface {
|
||||
VerifyToken(token string) error
|
||||
RefreshToken(refreshToken string) (*TokenResponse, error)
|
||||
}
|
||||
|
||||
// TokenResponse represents token exchange response
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// AuthFlowResult represents the result of authentication flow processing
|
||||
type AuthFlowResult struct {
|
||||
Authenticated bool
|
||||
RequiresAuth bool
|
||||
RequiresRefresh bool
|
||||
Error error
|
||||
RedirectURL string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// NewAuthFlowHandler creates a new authentication flow handler
|
||||
func NewAuthFlowHandler(sessionHandler *SessionHandler, tokenHandler TokenHandler, logger Logger, excludedURLs map[string]struct{}, initComplete chan struct{}, issuerURL string) *AuthFlowHandler {
|
||||
return &AuthFlowHandler{
|
||||
sessionHandler: sessionHandler,
|
||||
tokenHandler: tokenHandler,
|
||||
logger: logger,
|
||||
excludedURLs: excludedURLs,
|
||||
initComplete: initComplete,
|
||||
issuerURL: issuerURL,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessRequest handles the main authentication flow
|
||||
func (h *AuthFlowHandler) ProcessRequest(rw http.ResponseWriter, req *http.Request) AuthFlowResult {
|
||||
// Check if URL should be excluded
|
||||
if h.shouldExcludeURL(req.URL.Path) {
|
||||
h.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
return AuthFlowResult{Authenticated: true}
|
||||
}
|
||||
|
||||
// Check for streaming requests
|
||||
if h.isStreamingRequest(req) {
|
||||
h.logger.Debugf("Streaming request detected, bypassing OIDC")
|
||||
return AuthFlowResult{Authenticated: true}
|
||||
}
|
||||
|
||||
// Wait for initialization
|
||||
if !h.waitForInitialization(req) {
|
||||
return AuthFlowResult{
|
||||
Error: ErrInitializationTimeout,
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
}
|
||||
}
|
||||
|
||||
// Get and validate session
|
||||
session, err := h.sessionHandler.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Error getting session: %v", err)
|
||||
return AuthFlowResult{
|
||||
RequiresAuth: true,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
defer session.ReturnToPoolSafely()
|
||||
|
||||
// Clean up old cookies
|
||||
h.sessionHandler.sessionManager.CleanupOldCookies(rw, req)
|
||||
|
||||
// Validate session
|
||||
validationResult := h.sessionHandler.ValidateSession(session)
|
||||
if !validationResult.Valid {
|
||||
if validationResult.NeedsAuth {
|
||||
return AuthFlowResult{RequiresAuth: true}
|
||||
}
|
||||
return AuthFlowResult{
|
||||
Error: ErrSessionInvalid,
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
}
|
||||
}
|
||||
|
||||
// Check token validity and refresh if needed
|
||||
return h.validateAndRefreshTokens(session, req, rw)
|
||||
}
|
||||
|
||||
// shouldExcludeURL checks if a URL should bypass authentication
|
||||
func (h *AuthFlowHandler) shouldExcludeURL(path string) bool {
|
||||
for excludedURL := range h.excludedURLs {
|
||||
if len(path) >= len(excludedURL) && path[:len(excludedURL)] == excludedURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isStreamingRequest checks if request is a streaming request that should bypass auth
|
||||
func (h *AuthFlowHandler) isStreamingRequest(req *http.Request) bool {
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
return acceptHeader == "text/event-stream"
|
||||
}
|
||||
|
||||
// waitForInitialization waits for OIDC provider initialization
|
||||
func (h *AuthFlowHandler) waitForInitialization(req *http.Request) bool {
|
||||
select {
|
||||
case <-h.initComplete:
|
||||
if h.issuerURL == "" {
|
||||
h.logger.Error("OIDC provider metadata initialization failed")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case <-req.Context().Done():
|
||||
h.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
return false
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout waiting for OIDC initialization")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// validateAndRefreshTokens handles token validation and refresh logic
|
||||
func (h *AuthFlowHandler) validateAndRefreshTokens(session Session, req *http.Request, rw http.ResponseWriter) AuthFlowResult {
|
||||
// Check access token if present
|
||||
if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
if err := h.tokenHandler.VerifyToken(accessToken); err != nil {
|
||||
h.logger.Errorf("Access token validation failed: %v", err)
|
||||
|
||||
// Try refresh if refresh token is available
|
||||
if refreshToken := session.GetRefreshToken(); refreshToken != "" {
|
||||
return h.attemptTokenRefresh(session, req, rw)
|
||||
}
|
||||
|
||||
return AuthFlowResult{RequiresAuth: true}
|
||||
}
|
||||
}
|
||||
|
||||
// Check ID token
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
if err := h.tokenHandler.VerifyToken(idToken); err != nil {
|
||||
h.logger.Errorf("ID token validation failed: %v", err)
|
||||
|
||||
// Try refresh if refresh token is available
|
||||
if refreshToken := session.GetRefreshToken(); refreshToken != "" {
|
||||
return h.attemptTokenRefresh(session, req, rw)
|
||||
}
|
||||
|
||||
return AuthFlowResult{RequiresAuth: true}
|
||||
}
|
||||
}
|
||||
|
||||
return AuthFlowResult{Authenticated: true}
|
||||
}
|
||||
|
||||
// attemptTokenRefresh tries to refresh tokens
|
||||
func (h *AuthFlowHandler) attemptTokenRefresh(session Session, req *http.Request, rw http.ResponseWriter) AuthFlowResult {
|
||||
refreshToken := session.GetRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return AuthFlowResult{RequiresAuth: true}
|
||||
}
|
||||
|
||||
// Check if this is an AJAX request
|
||||
if h.sessionHandler.IsAjaxRequest(req) {
|
||||
return AuthFlowResult{
|
||||
Error: ErrSessionExpiredAjax,
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
}
|
||||
}
|
||||
|
||||
_, err := h.tokenHandler.RefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Token refresh failed: %v", err)
|
||||
return AuthFlowResult{RequiresAuth: true}
|
||||
}
|
||||
|
||||
// Update session with new tokens would be handled here
|
||||
// Implementation depends on the actual session interface
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save refreshed session: %v", err)
|
||||
return AuthFlowResult{
|
||||
Error: err,
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
return AuthFlowResult{Authenticated: true}
|
||||
}
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrInitializationTimeout = &AuthFlowError{Code: "INIT_TIMEOUT", Message: "OIDC initialization timeout"}
|
||||
ErrSessionInvalid = &AuthFlowError{Code: "SESSION_INVALID", Message: "Invalid session"}
|
||||
ErrSessionExpiredAjax = &AuthFlowError{Code: "SESSION_EXPIRED_AJAX", Message: "Session expired for AJAX request"}
|
||||
)
|
||||
|
||||
// AuthFlowError represents authentication flow errors
|
||||
type AuthFlowError struct {
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthFlowError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
@@ -1,588 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock implementations that embed SessionHandler
|
||||
type MockSessionHandlerWrapper struct {
|
||||
*SessionHandler
|
||||
}
|
||||
|
||||
func NewMockSessionHandlerWrapper() *MockSessionHandlerWrapper {
|
||||
sessionManager := &MockSessionManager{}
|
||||
logger := &MockLogger{}
|
||||
|
||||
sessionHandler := NewSessionHandler(
|
||||
sessionManager,
|
||||
logger,
|
||||
"/logout",
|
||||
"https://example.com/post-logout",
|
||||
"https://provider.example.com/logout",
|
||||
"test-client-id",
|
||||
)
|
||||
|
||||
return &MockSessionHandlerWrapper{
|
||||
SessionHandler: sessionHandler,
|
||||
}
|
||||
}
|
||||
|
||||
type MockSessionManager struct {
|
||||
session Session
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockSessionManager) GetSession(req *http.Request) (Session, error) {
|
||||
return m.session, m.err
|
||||
}
|
||||
|
||||
func (m *MockSessionManager) CleanupOldCookies(rw http.ResponseWriter, req *http.Request) {
|
||||
// Mock implementation
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
authenticated bool
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
saveError error
|
||||
clearError error
|
||||
}
|
||||
|
||||
func (m *MockSession) GetAuthenticated() bool { return m.authenticated }
|
||||
func (m *MockSession) SetAuthenticated(auth bool) error { m.authenticated = auth; return nil }
|
||||
func (m *MockSession) GetEmail() string { return m.email }
|
||||
func (m *MockSession) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSession) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSession) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSession) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSession) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
func (m *MockSession) Clear(req *http.Request, rw http.ResponseWriter) error { return m.clearError }
|
||||
func (m *MockSession) Save(req *http.Request, rw http.ResponseWriter) error { return m.saveError }
|
||||
func (m *MockSession) ReturnToPoolSafely() {}
|
||||
|
||||
type MockTokenHandler struct {
|
||||
verifyError error
|
||||
refreshError error
|
||||
tokenResponse *TokenResponse
|
||||
}
|
||||
|
||||
func (m *MockTokenHandler) VerifyToken(token string) error {
|
||||
return m.verifyError
|
||||
}
|
||||
|
||||
func (m *MockTokenHandler) RefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
return m.tokenResponse, m.refreshError
|
||||
}
|
||||
|
||||
type MockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (m *MockLogger) Debug(msg string) {
|
||||
m.debugMessages = append(m.debugMessages, msg)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Debugf(format string, args ...interface{}) {
|
||||
m.debugMessages = append(m.debugMessages, format)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Info(msg string) {}
|
||||
|
||||
func (m *MockLogger) Infof(format string, args ...interface{}) {}
|
||||
|
||||
func (m *MockLogger) Error(msg string) {
|
||||
m.errorMessages = append(m.errorMessages, msg)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Errorf(format string, args ...interface{}) {
|
||||
m.errorMessages = append(m.errorMessages, format)
|
||||
}
|
||||
|
||||
func TestNewAuthFlowHandler(t *testing.T) {
|
||||
sessionHandler := NewMockSessionHandlerWrapper()
|
||||
tokenHandler := &MockTokenHandler{}
|
||||
logger := &MockLogger{}
|
||||
excludedURLs := map[string]struct{}{"/health": {}}
|
||||
initComplete := make(chan struct{})
|
||||
issuerURL := "https://issuer.example.com"
|
||||
|
||||
handler := NewAuthFlowHandler(sessionHandler.SessionHandler, tokenHandler, logger, excludedURLs, initComplete, issuerURL)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("NewAuthFlowHandler returned nil")
|
||||
}
|
||||
|
||||
if handler.sessionHandler == nil {
|
||||
t.Error("SessionHandler not set correctly")
|
||||
}
|
||||
|
||||
if handler.tokenHandler != tokenHandler {
|
||||
t.Error("TokenHandler not set correctly")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.issuerURL != issuerURL {
|
||||
t.Error("IssuerURL not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_shouldExcludeURL(t *testing.T) {
|
||||
excludedURLs := map[string]struct{}{
|
||||
"/health": {},
|
||||
"/metrics": {},
|
||||
"/api/public": {},
|
||||
}
|
||||
|
||||
handler := &AuthFlowHandler{excludedURLs: excludedURLs}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"/health", true},
|
||||
{"/health/check", true},
|
||||
{"/metrics", true},
|
||||
{"/metrics/prometheus", true},
|
||||
{"/api/public", true},
|
||||
{"/api/public/endpoint", true},
|
||||
{"/api/private", false},
|
||||
{"/login", false},
|
||||
{"/dashboard", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := handler.shouldExcludeURL(test.path)
|
||||
if result != test.expected {
|
||||
t.Errorf("For path '%s': expected %v, got %v", test.path, test.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_isStreamingRequest(t *testing.T) {
|
||||
handler := &AuthFlowHandler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accept string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "SSE request",
|
||||
accept: "text/event-stream",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Regular HTML request",
|
||||
accept: "text/html,application/xhtml+xml",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "JSON request",
|
||||
accept: "application/json",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty accept header",
|
||||
accept: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Accept", test.accept)
|
||||
|
||||
result := handler.isStreamingRequest(req)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_waitForInitialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupHandler func() (*AuthFlowHandler, context.CancelFunc)
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "Initialization complete successfully",
|
||||
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
|
||||
initComplete := make(chan struct{})
|
||||
close(initComplete) // Already complete
|
||||
handler := &AuthFlowHandler{
|
||||
initComplete: initComplete,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
}
|
||||
return handler, nil
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "Initialization complete but no issuer URL",
|
||||
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
|
||||
initComplete := make(chan struct{})
|
||||
close(initComplete)
|
||||
handler := &AuthFlowHandler{
|
||||
initComplete: initComplete,
|
||||
issuerURL: "",
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
return handler, nil
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "Request canceled",
|
||||
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
|
||||
initComplete := make(chan struct{})
|
||||
handler := &AuthFlowHandler{
|
||||
initComplete: initComplete,
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
return handler, cancel
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler, cancelFunc := test.setupHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if cancelFunc != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req = req.WithContext(ctx)
|
||||
cancel() // Cancel immediately
|
||||
}
|
||||
|
||||
result := handler.waitForInitialization(req)
|
||||
if result != test.expectedResult {
|
||||
t.Errorf("Expected %v, got %v", test.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_ProcessRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
setupHandler func() *AuthFlowHandler
|
||||
expectedResult AuthFlowResult
|
||||
}{
|
||||
{
|
||||
name: "Excluded URL bypasses authentication",
|
||||
setupRequest: func() *http.Request {
|
||||
return httptest.NewRequest("GET", "/health", nil)
|
||||
},
|
||||
setupHandler: func() *AuthFlowHandler {
|
||||
return &AuthFlowHandler{
|
||||
excludedURLs: map[string]struct{}{"/health": {}},
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
{
|
||||
name: "Streaming request bypasses authentication",
|
||||
setupRequest: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
return req
|
||||
},
|
||||
setupHandler: func() *AuthFlowHandler {
|
||||
return &AuthFlowHandler{
|
||||
excludedURLs: map[string]struct{}{},
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
{
|
||||
name: "Initialization timeout",
|
||||
setupRequest: func() *http.Request {
|
||||
return httptest.NewRequest("GET", "/dashboard", nil)
|
||||
},
|
||||
setupHandler: func() *AuthFlowHandler {
|
||||
return &AuthFlowHandler{
|
||||
excludedURLs: map[string]struct{}{},
|
||||
initComplete: make(chan struct{}), // Never closes
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
},
|
||||
expectedResult: AuthFlowResult{
|
||||
Error: ErrInitializationTimeout,
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := test.setupRequest()
|
||||
handler := test.setupHandler()
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// For timeout test, use context with timeout
|
||||
if test.name == "Initialization timeout" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
result := handler.ProcessRequest(rw, req)
|
||||
|
||||
if result.Authenticated != test.expectedResult.Authenticated {
|
||||
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.StatusCode != test.expectedResult.StatusCode {
|
||||
t.Errorf("Expected StatusCode %d, got %d", test.expectedResult.StatusCode, result.StatusCode)
|
||||
}
|
||||
|
||||
if test.expectedResult.Error != nil && result.Error == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_validateAndRefreshTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
session *MockSession
|
||||
tokenHandler *MockTokenHandler
|
||||
expectedResult AuthFlowResult
|
||||
}{
|
||||
{
|
||||
name: "Valid access token",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
accessToken: "valid-access-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
verifyError: nil,
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
{
|
||||
name: "Invalid access token, successful refresh",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
accessToken: "invalid-access-token",
|
||||
refreshToken: "valid-refresh-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
verifyError: errors.New("token expired"),
|
||||
refreshError: nil,
|
||||
tokenResponse: &TokenResponse{
|
||||
IDToken: "new-id-token",
|
||||
AccessToken: "new-access-token",
|
||||
},
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
{
|
||||
name: "Invalid access token, no refresh token",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
accessToken: "invalid-access-token",
|
||||
refreshToken: "",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
verifyError: errors.New("token expired"),
|
||||
},
|
||||
expectedResult: AuthFlowResult{RequiresAuth: true},
|
||||
},
|
||||
{
|
||||
name: "Valid ID token only",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
idToken: "valid-id-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
verifyError: nil,
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &AuthFlowHandler{
|
||||
tokenHandler: test.tokenHandler,
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
result := handler.validateAndRefreshTokens(test.session, req, rw)
|
||||
|
||||
if result.Authenticated != test.expectedResult.Authenticated {
|
||||
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.RequiresAuth != test.expectedResult.RequiresAuth {
|
||||
t.Errorf("Expected RequiresAuth %v, got %v", test.expectedResult.RequiresAuth, result.RequiresAuth)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowHandler_attemptTokenRefresh(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
session *MockSession
|
||||
tokenHandler *MockTokenHandler
|
||||
isAjax bool
|
||||
expectedResult AuthFlowResult
|
||||
}{
|
||||
{
|
||||
name: "No refresh token",
|
||||
session: &MockSession{
|
||||
refreshToken: "",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{},
|
||||
expectedResult: AuthFlowResult{RequiresAuth: true},
|
||||
},
|
||||
{
|
||||
name: "AJAX request with expired session",
|
||||
session: &MockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{},
|
||||
isAjax: true,
|
||||
expectedResult: AuthFlowResult{
|
||||
Error: ErrSessionExpiredAjax,
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Successful token refresh",
|
||||
session: &MockSession{
|
||||
refreshToken: "valid-refresh-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
refreshError: nil,
|
||||
tokenResponse: &TokenResponse{
|
||||
IDToken: "new-id-token",
|
||||
AccessToken: "new-access-token",
|
||||
},
|
||||
},
|
||||
expectedResult: AuthFlowResult{Authenticated: true},
|
||||
},
|
||||
{
|
||||
name: "Failed token refresh",
|
||||
session: &MockSession{
|
||||
refreshToken: "invalid-refresh-token",
|
||||
},
|
||||
tokenHandler: &MockTokenHandler{
|
||||
refreshError: errors.New("refresh failed"),
|
||||
},
|
||||
expectedResult: AuthFlowResult{RequiresAuth: true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
sessionHandlerWrapper := NewMockSessionHandlerWrapper()
|
||||
handler := &AuthFlowHandler{
|
||||
sessionHandler: sessionHandlerWrapper.SessionHandler,
|
||||
tokenHandler: test.tokenHandler,
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if test.isAjax {
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
result := handler.attemptTokenRefresh(test.session, req, rw)
|
||||
|
||||
if result.Authenticated != test.expectedResult.Authenticated {
|
||||
t.Errorf("Expected Authenticated %v, got %v", test.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.RequiresAuth != test.expectedResult.RequiresAuth {
|
||||
t.Errorf("Expected RequiresAuth %v, got %v", test.expectedResult.RequiresAuth, result.RequiresAuth)
|
||||
}
|
||||
|
||||
if result.StatusCode != test.expectedResult.StatusCode {
|
||||
t.Errorf("Expected StatusCode %d, got %d", test.expectedResult.StatusCode, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowError_Error(t *testing.T) {
|
||||
err := &AuthFlowError{
|
||||
Code: "TEST_ERROR",
|
||||
Message: "This is a test error",
|
||||
}
|
||||
|
||||
expected := "This is a test error"
|
||||
result := err.Error()
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFlowResult(t *testing.T) {
|
||||
// Test AuthFlowResult struct
|
||||
result := AuthFlowResult{
|
||||
Authenticated: true,
|
||||
RequiresAuth: false,
|
||||
RequiresRefresh: false,
|
||||
Error: nil,
|
||||
RedirectURL: "https://example.com",
|
||||
StatusCode: 200,
|
||||
}
|
||||
|
||||
if !result.Authenticated {
|
||||
t.Error("Expected Authenticated to be true")
|
||||
}
|
||||
|
||||
if result.RequiresAuth {
|
||||
t.Error("Expected RequiresAuth to be false")
|
||||
}
|
||||
|
||||
if result.StatusCode != 200 {
|
||||
t.Errorf("Expected StatusCode 200, got %d", result.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenResponse(t *testing.T) {
|
||||
response := &TokenResponse{
|
||||
IDToken: "id-token-value",
|
||||
AccessToken: "access-token-value",
|
||||
RefreshToken: "refresh-token-value",
|
||||
ExpiresIn: 3600,
|
||||
}
|
||||
|
||||
if response.IDToken != "id-token-value" {
|
||||
t.Errorf("Expected IDToken 'id-token-value', got '%s'", response.IDToken)
|
||||
}
|
||||
|
||||
if response.ExpiresIn != 3600 {
|
||||
t.Errorf("Expected ExpiresIn 3600, got %d", response.ExpiresIn)
|
||||
}
|
||||
}
|
||||
@@ -1,247 +0,0 @@
|
||||
// Package handlers provides HTTP request handlers for OIDC operations
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SessionHandler manages session-related HTTP operations
|
||||
type SessionHandler struct {
|
||||
sessionManager SessionManager
|
||||
logger Logger
|
||||
logoutURLPath string
|
||||
postLogoutRedirectURI string
|
||||
endSessionURL string
|
||||
clientID string
|
||||
}
|
||||
|
||||
// SessionManager interface for session operations
|
||||
type SessionManager interface {
|
||||
GetSession(req *http.Request) (Session, error)
|
||||
CleanupOldCookies(rw http.ResponseWriter, req *http.Request)
|
||||
}
|
||||
|
||||
// Session interface for session data
|
||||
type Session interface {
|
||||
GetAuthenticated() bool
|
||||
SetAuthenticated(bool) error
|
||||
GetEmail() string
|
||||
SetEmail(string)
|
||||
GetIDToken() string
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
SetRefreshToken(string)
|
||||
Clear(req *http.Request, rw http.ResponseWriter) error
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
ReturnToPoolSafely()
|
||||
}
|
||||
|
||||
// Logger interface for logging operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewSessionHandler creates a new session handler
|
||||
func NewSessionHandler(sessionManager SessionManager, logger Logger, logoutURLPath, postLogoutRedirectURI, endSessionURL, clientID string) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
sessionManager: sessionManager,
|
||||
logger: logger,
|
||||
logoutURLPath: logoutURLPath,
|
||||
postLogoutRedirectURI: postLogoutRedirectURI,
|
||||
endSessionURL: endSessionURL,
|
||||
clientID: clientID,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleLogout processes logout requests
|
||||
func (h *SessionHandler) HandleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
h.logger.Debug("Processing logout request")
|
||||
|
||||
session, err := h.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Error getting session during logout: %v", err)
|
||||
// Continue with logout even if session retrieval fails
|
||||
}
|
||||
|
||||
var idToken string
|
||||
if session != nil {
|
||||
defer session.ReturnToPoolSafely()
|
||||
idToken = session.GetIDToken()
|
||||
|
||||
// Clear the session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
h.logger.Errorf("Error clearing session during logout: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build logout URL
|
||||
logoutURL := h.buildLogoutURL(idToken)
|
||||
|
||||
h.logger.Debugf("Redirecting to logout URL: %s", logoutURL)
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// buildLogoutURL constructs the provider logout URL
|
||||
func (h *SessionHandler) buildLogoutURL(idToken string) string {
|
||||
if h.endSessionURL == "" {
|
||||
// If no end session URL, redirect to post-logout redirect URI
|
||||
return h.postLogoutRedirectURI
|
||||
}
|
||||
|
||||
logoutURL := h.endSessionURL
|
||||
|
||||
// Add query parameters
|
||||
params := make([]string, 0, 3)
|
||||
|
||||
if idToken != "" {
|
||||
params = append(params, fmt.Sprintf("id_token_hint=%s", idToken))
|
||||
}
|
||||
|
||||
if h.postLogoutRedirectURI != "" {
|
||||
params = append(params, fmt.Sprintf("post_logout_redirect_uri=%s", h.postLogoutRedirectURI))
|
||||
}
|
||||
|
||||
if h.clientID != "" {
|
||||
params = append(params, fmt.Sprintf("client_id=%s", h.clientID))
|
||||
}
|
||||
|
||||
if len(params) > 0 {
|
||||
separator := "?"
|
||||
if strings.Contains(logoutURL, "?") {
|
||||
separator = "&"
|
||||
}
|
||||
logoutURL += separator + strings.Join(params, "&")
|
||||
}
|
||||
|
||||
return logoutURL
|
||||
}
|
||||
|
||||
// ValidateSession checks if a session is valid and authenticated
|
||||
func (h *SessionHandler) ValidateSession(session Session) SessionValidationResult {
|
||||
if session == nil {
|
||||
return SessionValidationResult{
|
||||
Valid: false,
|
||||
NeedsAuth: true,
|
||||
ErrorMessage: "session is nil",
|
||||
}
|
||||
}
|
||||
|
||||
if !session.GetAuthenticated() {
|
||||
return SessionValidationResult{
|
||||
Valid: false,
|
||||
NeedsAuth: true,
|
||||
ErrorMessage: "session not authenticated",
|
||||
}
|
||||
}
|
||||
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
return SessionValidationResult{
|
||||
Valid: false,
|
||||
NeedsAuth: true,
|
||||
ErrorMessage: "no email in session",
|
||||
}
|
||||
}
|
||||
|
||||
return SessionValidationResult{
|
||||
Valid: true,
|
||||
NeedsAuth: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SessionValidationResult represents the result of session validation
|
||||
type SessionValidationResult struct {
|
||||
Valid bool
|
||||
NeedsAuth bool
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
// CleanupExpiredSession clears an expired session
|
||||
func (h *SessionHandler) CleanupExpiredSession(rw http.ResponseWriter, req *http.Request, session Session) error {
|
||||
h.logger.Debug("Cleaning up expired session")
|
||||
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear all session data
|
||||
if err := session.SetAuthenticated(false); err != nil {
|
||||
h.logger.Errorf("Failed to set authenticated to false: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("")
|
||||
session.SetRefreshToken("")
|
||||
|
||||
// Save the cleared session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAjaxRequest determines if the request is an AJAX/XHR request
|
||||
func (h *SessionHandler) IsAjaxRequest(req *http.Request) bool {
|
||||
// Check X-Requested-With header (commonly used by jQuery and other libraries)
|
||||
if req.Header.Get("X-Requested-With") == "XMLHttpRequest" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check Accept header for JSON preference
|
||||
accept := req.Header.Get("Accept")
|
||||
if strings.Contains(accept, "application/json") && !strings.Contains(accept, "text/html") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for fetch API indication
|
||||
if req.Header.Get("Sec-Fetch-Mode") == "cors" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SendErrorResponse sends an appropriate error response based on request type
|
||||
func (h *SessionHandler) SendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, statusCode int) {
|
||||
if h.IsAjaxRequest(req) {
|
||||
// For AJAX requests, send JSON response
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(statusCode)
|
||||
_, _ = fmt.Fprintf(rw, `{"error": "%s"}`, message) // Safe to ignore: writing error response
|
||||
} else {
|
||||
// For browser requests, send HTML response
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(statusCode)
|
||||
_, _ = fmt.Fprintf(rw, `<html><body><h1>Error %d</h1><p>%s</p></body></html>`, statusCode, message) // Safe to ignore: writing error response
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecurityHeaders sets standard security headers
|
||||
func (h *SessionHandler) SetSecurityHeaders(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("X-Frame-Options", "DENY")
|
||||
rw.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
rw.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Handle CORS for AJAX requests
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
rw.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,587 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewSessionHandler(t *testing.T) {
|
||||
sessionManager := &MockSessionManager{}
|
||||
logger := &MockLogger{}
|
||||
logoutURLPath := "/logout"
|
||||
postLogoutRedirectURI := "https://example.com/post-logout"
|
||||
endSessionURL := "https://provider.example.com/logout"
|
||||
clientID := "test-client-id"
|
||||
|
||||
handler := NewSessionHandler(
|
||||
sessionManager,
|
||||
logger,
|
||||
logoutURLPath,
|
||||
postLogoutRedirectURI,
|
||||
endSessionURL,
|
||||
clientID,
|
||||
)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("NewSessionHandler returned nil")
|
||||
}
|
||||
|
||||
if handler.sessionManager != sessionManager {
|
||||
t.Error("SessionManager not set correctly")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.logoutURLPath != logoutURLPath {
|
||||
t.Error("LogoutURLPath not set correctly")
|
||||
}
|
||||
|
||||
if handler.postLogoutRedirectURI != postLogoutRedirectURI {
|
||||
t.Error("PostLogoutRedirectURI not set correctly")
|
||||
}
|
||||
|
||||
if handler.endSessionURL != endSessionURL {
|
||||
t.Error("EndSessionURL not set correctly")
|
||||
}
|
||||
|
||||
if handler.clientID != clientID {
|
||||
t.Error("ClientID not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_HandleLogout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *MockSession
|
||||
setupManager func() *MockSessionManager
|
||||
expectedCode int
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
name: "Successful logout with ID token",
|
||||
setupSession: func() *MockSession {
|
||||
return &MockSession{
|
||||
authenticated: true,
|
||||
idToken: "test-id-token",
|
||||
}
|
||||
},
|
||||
setupManager: func() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
idToken: "test-id-token",
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedCode: http.StatusFound,
|
||||
expectedURL: "https://provider.example.com/logout?id_token_hint=test-id-token&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
{
|
||||
name: "Logout without ID token",
|
||||
setupSession: func() *MockSession {
|
||||
return &MockSession{
|
||||
authenticated: true,
|
||||
idToken: "",
|
||||
}
|
||||
},
|
||||
setupManager: func() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
idToken: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedCode: http.StatusFound,
|
||||
expectedURL: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
{
|
||||
name: "Session retrieval error",
|
||||
setupSession: func() *MockSession { return nil },
|
||||
setupManager: func() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
err: fmt.Errorf("session error"),
|
||||
}
|
||||
},
|
||||
expectedCode: http.StatusFound,
|
||||
expectedURL: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &SessionHandler{
|
||||
sessionManager: test.setupManager(),
|
||||
logger: &MockLogger{},
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "https://example.com/post-logout",
|
||||
endSessionURL: "https://provider.example.com/logout",
|
||||
clientID: "test-client-id",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/logout", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleLogout(rw, req)
|
||||
|
||||
if rw.Code != test.expectedCode {
|
||||
t.Errorf("Expected status code %d, got %d", test.expectedCode, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != test.expectedURL {
|
||||
t.Errorf("Expected location '%s', got '%s'", test.expectedURL, location)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_buildLogoutURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endSessionURL string
|
||||
postLogoutRedirectURI string
|
||||
clientID string
|
||||
idToken string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Complete logout URL with all parameters",
|
||||
endSessionURL: "https://provider.example.com/logout",
|
||||
postLogoutRedirectURI: "https://example.com/post-logout",
|
||||
clientID: "test-client-id",
|
||||
idToken: "test-id-token",
|
||||
expected: "https://provider.example.com/logout?id_token_hint=test-id-token&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
{
|
||||
name: "Logout URL without ID token",
|
||||
endSessionURL: "https://provider.example.com/logout",
|
||||
postLogoutRedirectURI: "https://example.com/post-logout",
|
||||
clientID: "test-client-id",
|
||||
idToken: "",
|
||||
expected: "https://provider.example.com/logout?post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
{
|
||||
name: "No end session URL",
|
||||
endSessionURL: "",
|
||||
postLogoutRedirectURI: "https://example.com/post-logout",
|
||||
clientID: "test-client-id",
|
||||
idToken: "test-id-token",
|
||||
expected: "https://example.com/post-logout",
|
||||
},
|
||||
{
|
||||
name: "End session URL with existing query parameters",
|
||||
endSessionURL: "https://provider.example.com/logout?foo=bar",
|
||||
postLogoutRedirectURI: "https://example.com/post-logout",
|
||||
clientID: "test-client-id",
|
||||
idToken: "",
|
||||
expected: "https://provider.example.com/logout?foo=bar&post_logout_redirect_uri=https://example.com/post-logout&client_id=test-client-id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &SessionHandler{
|
||||
endSessionURL: test.endSessionURL,
|
||||
postLogoutRedirectURI: test.postLogoutRedirectURI,
|
||||
clientID: test.clientID,
|
||||
}
|
||||
|
||||
result := handler.buildLogoutURL(test.idToken)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_ValidateSession(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session Session
|
||||
expectedValid bool
|
||||
expectedAuth bool
|
||||
expectedMessage string
|
||||
}{
|
||||
{
|
||||
name: "Nil session",
|
||||
session: nil,
|
||||
expectedValid: false,
|
||||
expectedAuth: true,
|
||||
expectedMessage: "session is nil",
|
||||
},
|
||||
{
|
||||
name: "Not authenticated session",
|
||||
session: &MockSession{
|
||||
authenticated: false,
|
||||
},
|
||||
expectedValid: false,
|
||||
expectedAuth: true,
|
||||
expectedMessage: "session not authenticated",
|
||||
},
|
||||
{
|
||||
name: "Authenticated session without email",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
email: "",
|
||||
},
|
||||
expectedValid: false,
|
||||
expectedAuth: true,
|
||||
expectedMessage: "no email in session",
|
||||
},
|
||||
{
|
||||
name: "Valid authenticated session with email",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
email: "user@example.com",
|
||||
},
|
||||
expectedValid: true,
|
||||
expectedAuth: false,
|
||||
expectedMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := handler.ValidateSession(test.session)
|
||||
|
||||
if result.Valid != test.expectedValid {
|
||||
t.Errorf("Expected Valid %v, got %v", test.expectedValid, result.Valid)
|
||||
}
|
||||
|
||||
if result.NeedsAuth != test.expectedAuth {
|
||||
t.Errorf("Expected NeedsAuth %v, got %v", test.expectedAuth, result.NeedsAuth)
|
||||
}
|
||||
|
||||
if result.ErrorMessage != test.expectedMessage {
|
||||
t.Errorf("Expected ErrorMessage '%s', got '%s'", test.expectedMessage, result.ErrorMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_CleanupExpiredSession(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
session *MockSession
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Successful cleanup",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
email: "user@example.com",
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Save error during cleanup",
|
||||
session: &MockSession{
|
||||
authenticated: true,
|
||||
email: "user@example.com",
|
||||
refreshToken: "refresh-token",
|
||||
saveError: fmt.Errorf("save failed"),
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &SessionHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
err := handler.CleanupExpiredSession(rw, req, test.session)
|
||||
|
||||
if test.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
|
||||
if !test.expectError && err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
}
|
||||
|
||||
if test.session != nil && !test.expectError {
|
||||
if test.session.authenticated {
|
||||
t.Error("Expected session authenticated to be false after cleanup")
|
||||
}
|
||||
|
||||
if test.session.email != "" {
|
||||
t.Error("Expected session email to be empty after cleanup")
|
||||
}
|
||||
|
||||
if test.session.refreshToken != "" {
|
||||
t.Error("Expected session refresh token to be empty after cleanup")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test nil session separately
|
||||
t.Run("Nil session", func(t *testing.T) {
|
||||
handler := &SessionHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
var nilSession Session = nil
|
||||
err := handler.CleanupExpiredSession(rw, req, nilSession)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for nil session, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionHandler_IsAjaxRequest(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "XMLHttpRequest header",
|
||||
headers: map[string]string{
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON Accept header without HTML",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON Accept header with HTML",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json, text/html",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Fetch API CORS mode",
|
||||
headers: map[string]string{
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Regular browser request",
|
||||
headers: map[string]string{
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No special headers",
|
||||
headers: map[string]string{},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
for key, value := range test.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
result := handler.IsAjaxRequest(req)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_SendErrorResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isAjax bool
|
||||
message string
|
||||
statusCode int
|
||||
expectedContentType string
|
||||
expectedBodyContains string
|
||||
}{
|
||||
{
|
||||
name: "AJAX error response",
|
||||
isAjax: true,
|
||||
message: "Authentication failed",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedContentType: "application/json",
|
||||
expectedBodyContains: `{"error": "Authentication failed"}`,
|
||||
},
|
||||
{
|
||||
name: "Browser error response",
|
||||
isAjax: false,
|
||||
message: "Session expired",
|
||||
statusCode: http.StatusForbidden,
|
||||
expectedContentType: "text/html",
|
||||
expectedBodyContains: "<h1>Error 403</h1>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if test.isAjax {
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.SendErrorResponse(rw, req, test.message, test.statusCode)
|
||||
|
||||
if rw.Code != test.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", test.statusCode, rw.Code)
|
||||
}
|
||||
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
if contentType != test.expectedContentType {
|
||||
t.Errorf("Expected Content-Type '%s', got '%s'", test.expectedContentType, contentType)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, test.expectedBodyContains) {
|
||||
t.Errorf("Expected body to contain '%s', got '%s'", test.expectedBodyContains, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_SetSecurityHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
origin string
|
||||
expectedCORS bool
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Regular request without CORS",
|
||||
method: "GET",
|
||||
origin: "",
|
||||
expectedCORS: false,
|
||||
expectedStatus: 0, // No status written
|
||||
},
|
||||
{
|
||||
name: "CORS request with origin",
|
||||
method: "GET",
|
||||
origin: "https://example.com",
|
||||
expectedCORS: true,
|
||||
expectedStatus: 0,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS preflight request",
|
||||
method: "OPTIONS",
|
||||
origin: "https://example.com",
|
||||
expectedCORS: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
|
||||
req := httptest.NewRequest(test.method, "/", nil)
|
||||
if test.origin != "" {
|
||||
req.Header.Set("Origin", test.origin)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.SetSecurityHeaders(rw, req)
|
||||
|
||||
// Check standard security headers
|
||||
expectedSecurityHeaders := map[string]string{
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
}
|
||||
|
||||
for header, expectedValue := range expectedSecurityHeaders {
|
||||
actualValue := rw.Header().Get(header)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Expected %s header '%s', got '%s'", header, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Check CORS headers
|
||||
if test.expectedCORS {
|
||||
corsOrigin := rw.Header().Get("Access-Control-Allow-Origin")
|
||||
if corsOrigin != test.origin {
|
||||
t.Errorf("Expected CORS origin '%s', got '%s'", test.origin, corsOrigin)
|
||||
}
|
||||
|
||||
corsCredentials := rw.Header().Get("Access-Control-Allow-Credentials")
|
||||
if corsCredentials != "true" {
|
||||
t.Errorf("Expected CORS credentials 'true', got '%s'", corsCredentials)
|
||||
}
|
||||
|
||||
corsMethods := rw.Header().Get("Access-Control-Allow-Methods")
|
||||
if corsMethods != "GET, POST, OPTIONS" {
|
||||
t.Errorf("Expected CORS methods 'GET, POST, OPTIONS', got '%s'", corsMethods)
|
||||
}
|
||||
|
||||
corsHeaders := rw.Header().Get("Access-Control-Allow-Headers")
|
||||
if corsHeaders != "Authorization, Content-Type" {
|
||||
t.Errorf("Expected CORS headers 'Authorization, Content-Type', got '%s'", corsHeaders)
|
||||
}
|
||||
} else {
|
||||
corsOrigin := rw.Header().Get("Access-Control-Allow-Origin")
|
||||
if corsOrigin != "" {
|
||||
t.Errorf("Expected no CORS origin header, got '%s'", corsOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
// Check status code for OPTIONS requests
|
||||
if test.expectedStatus > 0 {
|
||||
if rw.Code != test.expectedStatus {
|
||||
t.Errorf("Expected status code %d, got %d", test.expectedStatus, rw.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionValidationResult(t *testing.T) {
|
||||
result := SessionValidationResult{
|
||||
Valid: true,
|
||||
NeedsAuth: false,
|
||||
ErrorMessage: "test message",
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
t.Error("Expected Valid to be true")
|
||||
}
|
||||
|
||||
if result.NeedsAuth {
|
||||
t.Error("Expected NeedsAuth to be false")
|
||||
}
|
||||
|
||||
if result.ErrorMessage != "test message" {
|
||||
t.Errorf("Expected ErrorMessage 'test message', got '%s'", result.ErrorMessage)
|
||||
}
|
||||
}
|
||||
@@ -1,545 +0,0 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config provides configuration for creating HTTP clients
|
||||
type Config struct {
|
||||
// Timeout for the entire request
|
||||
Timeout time.Duration
|
||||
// MaxRedirects allowed (0 means follow Go's default of 10)
|
||||
MaxRedirects int
|
||||
// UseCookieJar enables cookie jar for the client
|
||||
UseCookieJar bool
|
||||
// Connection settings
|
||||
DialTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
ResponseHeaderTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
IdleConnTimeout time.Duration
|
||||
// Connection pool settings
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
MaxConnsPerHost int
|
||||
// Buffer settings
|
||||
WriteBufferSize int
|
||||
ReadBufferSize int
|
||||
// Feature flags
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
// TLS configuration
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// ClientType defines the type of HTTP client for optimized behavior
|
||||
type ClientType string
|
||||
|
||||
const (
|
||||
ClientTypeDefault ClientType = "default"
|
||||
ClientTypeToken ClientType = "token"
|
||||
ClientTypeAPI ClientType = "api"
|
||||
ClientTypeProxy ClientType = "proxy"
|
||||
)
|
||||
|
||||
// PresetConfigs provides pre-configured settings for different client types
|
||||
var PresetConfigs = map[ClientType]Config{
|
||||
ClientTypeDefault: {
|
||||
Timeout: 10 * time.Second, // Reduced from 30s to prevent slowloris attacks
|
||||
MaxRedirects: 5, // Reduced from 10 to prevent redirect loops
|
||||
UseCookieJar: false,
|
||||
DialTimeout: 3 * time.Second,
|
||||
KeepAlive: 15 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
ResponseHeaderTimeout: 3 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 5 * time.Second,
|
||||
MaxIdleConns: 20, // Reduced from 100 to limit resource usage
|
||||
MaxIdleConnsPerHost: 2, // Reduced from 10 to prevent connection exhaustion
|
||||
MaxConnsPerHost: 5, // Reduced from 10 to limit concurrent connections
|
||||
WriteBufferSize: 4096,
|
||||
ReadBufferSize: 4096,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
},
|
||||
ClientTypeToken: {
|
||||
Timeout: 10 * time.Second,
|
||||
MaxRedirects: 50, // Token endpoints may redirect more
|
||||
UseCookieJar: true,
|
||||
DialTimeout: 3 * time.Second,
|
||||
KeepAlive: 15 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
ResponseHeaderTimeout: 3 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 5 * time.Second,
|
||||
MaxIdleConns: 10,
|
||||
MaxIdleConnsPerHost: 2,
|
||||
MaxConnsPerHost: 5,
|
||||
WriteBufferSize: 4096,
|
||||
ReadBufferSize: 4096,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
},
|
||||
ClientTypeAPI: {
|
||||
Timeout: 30 * time.Second, // Longer for API operations
|
||||
MaxRedirects: 10,
|
||||
UseCookieJar: false,
|
||||
DialTimeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
MaxConnsPerHost: 10,
|
||||
WriteBufferSize: 8192,
|
||||
ReadBufferSize: 8192,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
},
|
||||
ClientTypeProxy: {
|
||||
Timeout: 60 * time.Second, // Proxy needs longer timeouts
|
||||
MaxRedirects: 0, // Proxy should not follow redirects
|
||||
UseCookieJar: false,
|
||||
DialTimeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 20,
|
||||
WriteBufferSize: 16384,
|
||||
ReadBufferSize: 16384,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: true, // Proxy should not modify content
|
||||
},
|
||||
}
|
||||
|
||||
// Factory provides methods for creating configured HTTP clients
|
||||
type Factory struct {
|
||||
pool *TransportPool
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Logger interface for HTTP client operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
var (
|
||||
globalFactory *Factory
|
||||
globalFactoryOnce sync.Once
|
||||
)
|
||||
|
||||
// GetGlobalFactory returns the singleton HTTP client factory
|
||||
func GetGlobalFactory(logger Logger) *Factory {
|
||||
globalFactoryOnce.Do(func() {
|
||||
globalFactory = NewFactory(logger)
|
||||
})
|
||||
return globalFactory
|
||||
}
|
||||
|
||||
// NewFactory creates a new HTTP client factory
|
||||
func NewFactory(logger Logger) *Factory {
|
||||
if logger == nil {
|
||||
logger = &noOpLogger{}
|
||||
}
|
||||
return &Factory{
|
||||
pool: GetGlobalTransportPool(),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateClient creates an HTTP client with the specified configuration
|
||||
func (f *Factory) CreateClient(config Config) (*http.Client, error) {
|
||||
// Validate configuration
|
||||
if err := f.ValidateConfig(&config); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
// Apply TLS configuration if not provided
|
||||
if config.TLSConfig == nil {
|
||||
config.TLSConfig = f.createSecureTLSConfig()
|
||||
}
|
||||
|
||||
// Get or create transport from pool
|
||||
transport := f.pool.GetOrCreateTransport(config)
|
||||
if transport == nil {
|
||||
return nil, fmt.Errorf("failed to create transport: client limit exceeded")
|
||||
}
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: config.Timeout,
|
||||
}
|
||||
|
||||
// Configure redirect policy
|
||||
if config.MaxRedirects > 0 {
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= config.MaxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", config.MaxRedirects)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add cookie jar if requested
|
||||
if config.UseCookieJar {
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cookie jar: %w", err)
|
||||
}
|
||||
client.Jar = jar
|
||||
}
|
||||
|
||||
f.logger.Debugf("Created HTTP client with config: timeout=%v, maxRedirects=%d", config.Timeout, config.MaxRedirects)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// CreateClientWithPreset creates an HTTP client using a preset configuration
|
||||
func (f *Factory) CreateClientWithPreset(clientType ClientType) (*http.Client, error) {
|
||||
config, ok := PresetConfigs[clientType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown client type: %s", clientType)
|
||||
}
|
||||
return f.CreateClient(config)
|
||||
}
|
||||
|
||||
// CreateDefault creates a default HTTP client
|
||||
func (f *Factory) CreateDefault() (*http.Client, error) {
|
||||
return f.CreateClientWithPreset(ClientTypeDefault)
|
||||
}
|
||||
|
||||
// CreateToken creates an HTTP client optimized for token operations
|
||||
func (f *Factory) CreateToken() (*http.Client, error) {
|
||||
return f.CreateClientWithPreset(ClientTypeToken)
|
||||
}
|
||||
|
||||
// CreateAPI creates an HTTP client optimized for API operations
|
||||
func (f *Factory) CreateAPI() (*http.Client, error) {
|
||||
return f.CreateClientWithPreset(ClientTypeAPI)
|
||||
}
|
||||
|
||||
// CreateProxy creates an HTTP client optimized for proxy operations
|
||||
func (f *Factory) CreateProxy() (*http.Client, error) {
|
||||
return f.CreateClientWithPreset(ClientTypeProxy)
|
||||
}
|
||||
|
||||
// ValidateConfig validates HTTP client configuration parameters
|
||||
func (f *Factory) ValidateConfig(config *Config) error {
|
||||
// Validate connection pool limits
|
||||
if config.MaxIdleConns < 0 {
|
||||
return fmt.Errorf("MaxIdleConns cannot be negative: %d", config.MaxIdleConns)
|
||||
}
|
||||
if config.MaxIdleConns > 1000 {
|
||||
return fmt.Errorf("MaxIdleConns too high (max 1000): %d", config.MaxIdleConns)
|
||||
}
|
||||
|
||||
if config.MaxIdleConnsPerHost < 0 {
|
||||
return fmt.Errorf("MaxIdleConnsPerHost cannot be negative: %d", config.MaxIdleConnsPerHost)
|
||||
}
|
||||
if config.MaxIdleConnsPerHost > 100 {
|
||||
return fmt.Errorf("MaxIdleConnsPerHost too high (max 100): %d", config.MaxIdleConnsPerHost)
|
||||
}
|
||||
|
||||
if config.MaxConnsPerHost < 0 {
|
||||
return fmt.Errorf("MaxConnsPerHost cannot be negative: %d", config.MaxConnsPerHost)
|
||||
}
|
||||
if config.MaxConnsPerHost > 200 {
|
||||
return fmt.Errorf("MaxConnsPerHost too high (max 200): %d", config.MaxConnsPerHost)
|
||||
}
|
||||
|
||||
// Validate timeouts
|
||||
if config.Timeout < 0 {
|
||||
return fmt.Errorf("timeout cannot be negative")
|
||||
}
|
||||
if config.Timeout > 5*time.Minute {
|
||||
return fmt.Errorf("timeout too long (max 5 minutes): %v", config.Timeout)
|
||||
}
|
||||
|
||||
// Validate buffer sizes
|
||||
if config.WriteBufferSize < 0 || config.ReadBufferSize < 0 {
|
||||
return fmt.Errorf("buffer sizes cannot be negative")
|
||||
}
|
||||
if config.WriteBufferSize > 1024*1024 || config.ReadBufferSize > 1024*1024 {
|
||||
return fmt.Errorf("buffer sizes too large (max 1MB)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSecureTLSConfig creates a secure TLS configuration
|
||||
func (f *Factory) createSecureTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS12, // SECURITY: Enforce TLS 1.2 minimum
|
||||
MaxVersion: tls.VersionTLS13, // Support up to TLS 1.3
|
||||
CipherSuites: []uint16{
|
||||
// TLS 1.3 cipher suites (automatically selected when TLS 1.3 is negotiated)
|
||||
// TLS 1.2 secure cipher suites
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
PreferServerCipherSuites: false, // Let client choose best cipher
|
||||
}
|
||||
}
|
||||
|
||||
// TransportPool manages a pool of shared HTTP transports
|
||||
type TransportPool struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*sharedTransport
|
||||
maxConns int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Resource limits
|
||||
clientCount int32 // Track total HTTP clients
|
||||
maxClients int32 // Limit total clients
|
||||
}
|
||||
|
||||
type sharedTransport struct {
|
||||
transport *http.Transport
|
||||
refCount int32
|
||||
lastUsed time.Time
|
||||
config Config
|
||||
}
|
||||
|
||||
var (
|
||||
globalTransportPool *TransportPool
|
||||
globalTransportPoolOnce sync.Once
|
||||
)
|
||||
|
||||
// GetGlobalTransportPool returns the singleton transport pool instance
|
||||
func GetGlobalTransportPool() *TransportPool {
|
||||
globalTransportPoolOnce.Do(func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
globalTransportPool = &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20, // Reduced from 100 to prevent resource exhaustion
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientCount: 0,
|
||||
maxClients: 5, // Maximum 5 HTTP clients
|
||||
}
|
||||
// Start cleanup goroutine with context cancellation
|
||||
go globalTransportPool.cleanupIdleTransports(ctx)
|
||||
})
|
||||
return globalTransportPool
|
||||
}
|
||||
|
||||
// GetOrCreateTransport gets or creates a shared transport with the given config
|
||||
func (p *TransportPool) GetOrCreateTransport(config Config) *http.Transport {
|
||||
// Check client limit before creating new transport
|
||||
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
|
||||
// Try to return existing transport if limit reached
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
for _, shared := range p.transports {
|
||||
if shared != nil && shared.transport != nil {
|
||||
atomic.AddInt32(&shared.refCount, 1)
|
||||
shared.lastUsed = time.Now()
|
||||
return shared.transport
|
||||
}
|
||||
}
|
||||
// If no transport available, return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
key := p.configKey(config)
|
||||
|
||||
if shared, exists := p.transports[key]; exists {
|
||||
atomic.AddInt32(&shared.refCount, 1)
|
||||
shared.lastUsed = time.Now()
|
||||
return shared.transport
|
||||
}
|
||||
|
||||
// Create new transport
|
||||
transport := p.createTransport(config)
|
||||
|
||||
p.transports[key] = &sharedTransport{
|
||||
transport: transport,
|
||||
refCount: 1,
|
||||
lastUsed: time.Now(),
|
||||
config: config,
|
||||
}
|
||||
|
||||
atomic.AddInt32(&p.clientCount, 1)
|
||||
return transport
|
||||
}
|
||||
|
||||
// createTransport creates a new HTTP transport with the given configuration
|
||||
func (p *TransportPool) createTransport(config Config) *http.Transport {
|
||||
// Create secure TLS config if not provided
|
||||
tlsConfig := config.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: config.DialTimeout,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}).DialContext,
|
||||
TLSClientConfig: tlsConfig,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: config.ResponseHeaderTimeout,
|
||||
ExpectContinueTimeout: config.ExpectContinueTimeout,
|
||||
IdleConnTimeout: config.IdleConnTimeout,
|
||||
MaxIdleConns: config.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
|
||||
MaxConnsPerHost: config.MaxConnsPerHost,
|
||||
WriteBufferSize: config.WriteBufferSize,
|
||||
ReadBufferSize: config.ReadBufferSize,
|
||||
ForceAttemptHTTP2: config.ForceHTTP2,
|
||||
DisableKeepAlives: config.DisableKeepAlives,
|
||||
DisableCompression: config.DisableCompression,
|
||||
}
|
||||
}
|
||||
|
||||
// configKey generates a unique key for the configuration
|
||||
func (p *TransportPool) configKey(config Config) string {
|
||||
return fmt.Sprintf("%v-%d-%d-%d-%d-%v-%v-%v",
|
||||
config.Timeout,
|
||||
config.MaxIdleConns,
|
||||
config.MaxIdleConnsPerHost,
|
||||
config.MaxConnsPerHost,
|
||||
config.MaxRedirects,
|
||||
config.ForceHTTP2,
|
||||
config.DisableKeepAlives,
|
||||
config.DisableCompression,
|
||||
)
|
||||
}
|
||||
|
||||
// cleanupIdleTransports periodically cleans up idle transports
|
||||
func (p *TransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanupIdle()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupIdle removes idle transports with zero references
|
||||
func (p *TransportPool) cleanupIdle() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
var toRemove []string
|
||||
|
||||
for key, shared := range p.transports {
|
||||
if atomic.LoadInt32(&shared.refCount) == 0 && now.Sub(shared.lastUsed) > 10*time.Minute {
|
||||
if shared.transport != nil {
|
||||
shared.transport.CloseIdleConnections()
|
||||
}
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range toRemove {
|
||||
delete(p.transports, key)
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// Release decrements the reference count for a transport
|
||||
func (p *TransportPool) Release(transport *http.Transport) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, shared := range p.transports {
|
||||
if shared.transport == transport {
|
||||
atomic.AddInt32(&shared.refCount, -1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the transport pool
|
||||
func (p *TransportPool) Close() error {
|
||||
p.cancel()
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for key, shared := range p.transports {
|
||||
if shared.transport != nil {
|
||||
shared.transport.CloseIdleConnections()
|
||||
}
|
||||
delete(p.transports, key)
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&p.clientCount, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// noOpLogger provides a no-op logger implementation
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (l *noOpLogger) Debug(msg string) {}
|
||||
func (l *noOpLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (l *noOpLogger) Info(msg string) {}
|
||||
func (l *noOpLogger) Infof(format string, args ...interface{}) {}
|
||||
func (l *noOpLogger) Error(msg string) {}
|
||||
func (l *noOpLogger) Errorf(format string, args ...interface{}) {}
|
||||
|
||||
// Compatibility functions for backward compatibility
|
||||
|
||||
// CreateDefaultHTTPClient creates a default HTTP client
|
||||
func CreateDefaultHTTPClient() *http.Client {
|
||||
factory := GetGlobalFactory(nil)
|
||||
client, _ := factory.CreateDefault()
|
||||
return client
|
||||
}
|
||||
|
||||
// CreateTokenHTTPClient creates an HTTP client optimized for token operations
|
||||
func CreateTokenHTTPClient() *http.Client {
|
||||
factory := GetGlobalFactory(nil)
|
||||
client, _ := factory.CreateToken()
|
||||
return client
|
||||
}
|
||||
|
||||
// CreateHTTPClientWithConfig creates an HTTP client with custom configuration
|
||||
func CreateHTTPClientWithConfig(config Config) *http.Client {
|
||||
factory := GetGlobalFactory(nil)
|
||||
client, _ := factory.CreateClient(config)
|
||||
return client
|
||||
}
|
||||
@@ -1,408 +0,0 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCreateProxy tests the CreateProxy method
|
||||
func TestCreateProxy(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
client, err := factory.CreateProxy()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy client: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil proxy client")
|
||||
}
|
||||
|
||||
// Verify proxy configuration specifics
|
||||
if client.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected proxy timeout to be 60s, got %v", client.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateConfigEdgeCases tests additional validation scenarios
|
||||
func TestValidateConfigEdgeCases(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config Config
|
||||
shouldFail bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Negative MaxIdleConnsPerHost",
|
||||
config: Config{
|
||||
MaxIdleConnsPerHost: -1,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "MaxIdleConnsPerHost cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "Excessive MaxIdleConnsPerHost",
|
||||
config: Config{
|
||||
MaxIdleConnsPerHost: 200,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "MaxIdleConnsPerHost too high",
|
||||
},
|
||||
{
|
||||
name: "Negative MaxConnsPerHost",
|
||||
config: Config{
|
||||
MaxConnsPerHost: -1,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "MaxConnsPerHost cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "Excessive MaxConnsPerHost",
|
||||
config: Config{
|
||||
MaxConnsPerHost: 300,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "MaxConnsPerHost too high",
|
||||
},
|
||||
{
|
||||
name: "Negative WriteBufferSize",
|
||||
config: Config{
|
||||
WriteBufferSize: -1,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "buffer sizes cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "Negative ReadBufferSize",
|
||||
config: Config{
|
||||
ReadBufferSize: -1,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "buffer sizes cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "Excessive WriteBufferSize",
|
||||
config: Config{
|
||||
WriteBufferSize: 2 * 1024 * 1024,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "buffer sizes too large",
|
||||
},
|
||||
{
|
||||
name: "Excessive ReadBufferSize",
|
||||
config: Config{
|
||||
ReadBufferSize: 2 * 1024 * 1024,
|
||||
},
|
||||
shouldFail: true,
|
||||
errorMsg: "buffer sizes too large",
|
||||
},
|
||||
{
|
||||
name: "Valid edge values",
|
||||
config: Config{
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
MaxConnsPerHost: 200,
|
||||
Timeout: 5 * time.Minute,
|
||||
WriteBufferSize: 1024 * 1024,
|
||||
ReadBufferSize: 1024 * 1024,
|
||||
},
|
||||
shouldFail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := factory.ValidateConfig(&tc.config)
|
||||
if tc.shouldFail {
|
||||
if err == nil {
|
||||
t.Fatalf("Expected validation to fail with message containing: %s", tc.errorMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransportPoolClose tests the Close method of TransportPool
|
||||
func TestTransportPoolClose(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create some transports
|
||||
config := PresetConfigs[ClientTypeDefault]
|
||||
transport1 := pool.GetOrCreateTransport(config)
|
||||
if transport1 == nil {
|
||||
t.Fatal("Failed to create transport")
|
||||
}
|
||||
|
||||
// Modify config slightly to create a different transport
|
||||
config.Timeout = 20 * time.Second
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
if transport2 == nil {
|
||||
t.Fatal("Failed to create second transport")
|
||||
}
|
||||
|
||||
// Verify transports were created
|
||||
pool.mu.RLock()
|
||||
initialCount := len(pool.transports)
|
||||
pool.mu.RUnlock()
|
||||
if initialCount == 0 {
|
||||
t.Fatal("Expected transports to be created")
|
||||
}
|
||||
|
||||
// Close the pool
|
||||
err := pool.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close pool: %v", err)
|
||||
}
|
||||
|
||||
// Verify all transports were removed
|
||||
pool.mu.RLock()
|
||||
finalCount := len(pool.transports)
|
||||
pool.mu.RUnlock()
|
||||
if finalCount != 0 {
|
||||
t.Fatalf("Expected 0 transports after close, got %d", finalCount)
|
||||
}
|
||||
|
||||
// Verify client count was reset
|
||||
if pool.clientCount != 0 {
|
||||
t.Fatalf("Expected client count to be 0 after close, got %d", pool.clientCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNoOpLogger tests the no-op logger implementation
|
||||
func TestNoOpLogger(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// These should not panic or cause any issues
|
||||
logger.Debug("test debug")
|
||||
logger.Debugf("test debug %s", "formatted")
|
||||
logger.Info("test info")
|
||||
logger.Infof("test info %s", "formatted")
|
||||
logger.Error("test error")
|
||||
logger.Errorf("test error %s", "formatted")
|
||||
|
||||
// Test using logger with factory
|
||||
factory := NewFactory(logger)
|
||||
client, err := factory.CreateDefault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client with no-op logger: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateClientWithCustomTLS tests creating client with custom TLS config
|
||||
func TestCreateClientWithCustomTLS(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
customTLS := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Timeout: 10 * time.Second,
|
||||
MaxIdleConns: 10,
|
||||
MaxIdleConnsPerHost: 2,
|
||||
MaxConnsPerHost: 5,
|
||||
TLSConfig: customTLS,
|
||||
}
|
||||
|
||||
client, err := factory.CreateClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client with custom TLS: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateClientWithMaxRedirects tests redirect limiting
|
||||
func TestCreateClientWithMaxRedirects(t *testing.T) {
|
||||
redirectCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
redirectCount++
|
||||
if redirectCount <= 3 {
|
||||
http.Redirect(w, r, "/redirect", http.StatusFound)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("final"))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
factory := NewFactory(nil)
|
||||
|
||||
// Test with max redirects = 2 (should fail)
|
||||
config := Config{
|
||||
Timeout: 10 * time.Second,
|
||||
MaxRedirects: 2,
|
||||
MaxIdleConns: 10,
|
||||
MaxIdleConnsPerHost: 2,
|
||||
MaxConnsPerHost: 5,
|
||||
}
|
||||
|
||||
client, err := factory.CreateClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
redirectCount = 0
|
||||
_, err = client.Get(server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("Expected redirect limit error")
|
||||
}
|
||||
|
||||
// Test with max redirects = 5 (should succeed)
|
||||
config.MaxRedirects = 5
|
||||
client, err = factory.CreateClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
redirectCount = 0
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransportPoolMaxClientsLimit tests the max clients limitation
|
||||
func TestTransportPoolMaxClientsLimit(t *testing.T) {
|
||||
pool := &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 2, // Set low limit for testing
|
||||
}
|
||||
|
||||
// Create transports up to the limit
|
||||
configs := []Config{
|
||||
{Timeout: 10 * time.Second},
|
||||
{Timeout: 20 * time.Second},
|
||||
{Timeout: 30 * time.Second}, // This should not create a new transport
|
||||
}
|
||||
|
||||
for i, config := range configs {
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
if i < 2 {
|
||||
if transport == nil {
|
||||
t.Fatalf("Expected transport %d to be created", i)
|
||||
}
|
||||
// Transport created successfully within limit
|
||||
} else {
|
||||
// When limit is reached, should return existing transport or nil
|
||||
if transport == nil {
|
||||
// This is acceptable - nil when limit reached
|
||||
t.Log("Transport creation blocked due to client limit")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify client count doesn't exceed limit
|
||||
if pool.clientCount > pool.maxClients {
|
||||
t.Fatalf("Client count %d exceeds max %d", pool.clientCount, pool.maxClients)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupIdleTransportsContext tests cleanup goroutine with context
|
||||
func TestCleanupIdleTransportsContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
pool.cleanupIdleTransports(ctx)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Cancel context to stop cleanup
|
||||
cancel()
|
||||
|
||||
// Wait for goroutine to exit
|
||||
select {
|
||||
case <-done:
|
||||
// Success - goroutine exited
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Cleanup goroutine did not exit after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFactoryWithLogger tests factory creation with custom logger
|
||||
func TestFactoryWithLogger(t *testing.T) {
|
||||
// Create a mock logger that implements the Logger interface
|
||||
logger := &MockLogger{}
|
||||
|
||||
factory := NewFactory(logger)
|
||||
if factory.logger == nil {
|
||||
t.Fatal("Expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// MockLogger for testing
|
||||
type MockLogger struct {
|
||||
debugCalled bool
|
||||
debugfCalled bool
|
||||
infoCalled bool
|
||||
infofCalled bool
|
||||
errorCalled bool
|
||||
errorfCalled bool
|
||||
}
|
||||
|
||||
func (m *MockLogger) Debug(msg string) { m.debugCalled = true }
|
||||
func (m *MockLogger) Debugf(format string, args ...interface{}) { m.debugfCalled = true }
|
||||
func (m *MockLogger) Info(msg string) { m.infoCalled = true }
|
||||
func (m *MockLogger) Infof(format string, args ...interface{}) { m.infofCalled = true }
|
||||
func (m *MockLogger) Error(msg string) { m.errorCalled = true }
|
||||
func (m *MockLogger) Errorf(format string, args ...interface{}) { m.errorfCalled = true }
|
||||
|
||||
// TestCreateClientLogging tests that logger is called during client creation
|
||||
func TestCreateClientLogging(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
factory := NewFactory(logger)
|
||||
|
||||
client, err := factory.CreateDefault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
|
||||
// Verify logger was called
|
||||
if !logger.debugfCalled {
|
||||
t.Error("Expected Debugf to be called during client creation")
|
||||
}
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFactoryCreateClient(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
// Test creating default client
|
||||
client, err := factory.CreateDefault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create default client: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
|
||||
// Test creating token client
|
||||
tokenClient, err := factory.CreateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token client: %v", err)
|
||||
}
|
||||
if tokenClient == nil {
|
||||
t.Fatal("Expected non-nil token client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryCreateClientWithPreset(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
clientType ClientType
|
||||
shouldFail bool
|
||||
}{
|
||||
{"Default", ClientTypeDefault, false},
|
||||
{"Token", ClientTypeToken, false},
|
||||
{"API", ClientTypeAPI, false},
|
||||
{"Proxy", ClientTypeProxy, false},
|
||||
{"Invalid", ClientType("invalid"), true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
client, err := factory.CreateClientWithPreset(tc.clientType)
|
||||
if tc.shouldFail {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid client type")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create %s client: %v", tc.clientType, err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryValidateConfig(t *testing.T) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config Config
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
config: PresetConfigs[ClientTypeDefault],
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
name: "Negative MaxIdleConns",
|
||||
config: Config{
|
||||
MaxIdleConns: -1,
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Excessive MaxIdleConns",
|
||||
config: Config{
|
||||
MaxIdleConns: 2000,
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Negative timeout",
|
||||
config: Config{
|
||||
Timeout: -1 * time.Second,
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Excessive timeout",
|
||||
config: Config{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := factory.ValidateConfig(&tc.config)
|
||||
if tc.shouldFail && err == nil {
|
||||
t.Fatal("Expected validation to fail")
|
||||
}
|
||||
if !tc.shouldFail && err != nil {
|
||||
t.Fatalf("Unexpected validation error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportPoolConcurrency(t *testing.T) {
|
||||
pool := &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := PresetConfigs[ClientTypeDefault]
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
|
||||
// Test concurrent transport creation
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
if transport != nil {
|
||||
// Simulate usage
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
pool.Release(transport)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Verify client count is within limits
|
||||
clientCount := atomic.LoadInt32(&pool.clientCount)
|
||||
if clientCount > pool.maxClients {
|
||||
t.Fatalf("Client count %d exceeds max %d", clientCount, pool.maxClients)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientRequests(t *testing.T) {
|
||||
// Create test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
factory := NewFactory(nil)
|
||||
client, err := factory.CreateDefault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Make request
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWithCookieJar(t *testing.T) {
|
||||
config := PresetConfigs[ClientTypeToken]
|
||||
if !config.UseCookieJar {
|
||||
t.Skip("Token client should have cookie jar enabled")
|
||||
}
|
||||
|
||||
factory := NewFactory(nil)
|
||||
client, err := factory.CreateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token client: %v", err)
|
||||
}
|
||||
|
||||
if client.Jar == nil {
|
||||
t.Fatal("Expected cookie jar to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportPoolCleanup(t *testing.T) {
|
||||
pool := &TransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := PresetConfigs[ClientTypeDefault]
|
||||
|
||||
// Create transport
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
if transport == nil {
|
||||
t.Fatal("Failed to create transport")
|
||||
}
|
||||
|
||||
// Release transport
|
||||
pool.Release(transport)
|
||||
|
||||
// Simulate idle time
|
||||
pool.mu.Lock()
|
||||
for _, shared := range pool.transports {
|
||||
shared.lastUsed = time.Now().Add(-11 * time.Minute)
|
||||
atomic.StoreInt32(&shared.refCount, 0)
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup
|
||||
pool.cleanupIdle()
|
||||
|
||||
// Verify transport was removed
|
||||
pool.mu.RLock()
|
||||
count := len(pool.transports)
|
||||
pool.mu.RUnlock()
|
||||
|
||||
if count != 0 {
|
||||
t.Fatalf("Expected 0 transports after cleanup, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalFactorySingleton(t *testing.T) {
|
||||
factory1 := GetGlobalFactory(nil)
|
||||
factory2 := GetGlobalFactory(nil)
|
||||
|
||||
if factory1 != factory2 {
|
||||
t.Fatal("Expected singleton factory instances to be the same")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompatibilityFunctions(t *testing.T) {
|
||||
// Test CreateDefaultHTTPClient
|
||||
defaultClient := CreateDefaultHTTPClient()
|
||||
if defaultClient == nil {
|
||||
t.Fatal("Expected non-nil default client")
|
||||
}
|
||||
|
||||
// Test CreateTokenHTTPClient
|
||||
tokenClient := CreateTokenHTTPClient()
|
||||
if tokenClient == nil {
|
||||
t.Fatal("Expected non-nil token client")
|
||||
}
|
||||
|
||||
// Test CreateHTTPClientWithConfig
|
||||
config := PresetConfigs[ClientTypeAPI]
|
||||
apiClient := CreateHTTPClientWithConfig(config)
|
||||
if apiClient == nil {
|
||||
t.Fatal("Expected non-nil API client")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFactoryCreateClient(b *testing.B) {
|
||||
factory := NewFactory(nil)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
client, err := factory.CreateDefault()
|
||||
if err != nil || client == nil {
|
||||
b.Fatal("Failed to create client")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkTransportPoolGetOrCreate(b *testing.B) {
|
||||
pool := GetGlobalTransportPool()
|
||||
config := PresetConfigs[ClientTypeDefault]
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
if transport != nil {
|
||||
pool.Release(transport)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// LegacyLoggerAdapter wraps the old Logger struct from the main package
|
||||
// to implement the new unified Logger interface. This allows for gradual
|
||||
// migration of the codebase to the new logger interface.
|
||||
type LegacyLoggerAdapter struct {
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLegacyAdapter creates a new adapter from the old logger components
|
||||
func NewLegacyAdapter(logError, logInfo, logDebug *log.Logger) Logger {
|
||||
if logError == nil || logInfo == nil || logDebug == nil {
|
||||
return GetNoOpLogger()
|
||||
}
|
||||
return &LegacyLoggerAdapter{
|
||||
logError: logError,
|
||||
logInfo: logInfo,
|
||||
logDebug: logDebug,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *LegacyLoggerAdapter) Debug(msg string) {
|
||||
l.logDebug.Print(msg)
|
||||
}
|
||||
|
||||
// Debugf logs a formatted debug message
|
||||
func (l *LegacyLoggerAdapter) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *LegacyLoggerAdapter) Info(msg string) {
|
||||
l.logInfo.Print(msg)
|
||||
}
|
||||
|
||||
// Infof logs a formatted info message
|
||||
func (l *LegacyLoggerAdapter) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *LegacyLoggerAdapter) Error(msg string) {
|
||||
l.logError.Print(msg)
|
||||
}
|
||||
|
||||
// Errorf logs a formatted error message
|
||||
func (l *LegacyLoggerAdapter) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Printf logs a formatted message at info level
|
||||
func (l *LegacyLoggerAdapter) Printf(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Println logs a message at info level
|
||||
func (l *LegacyLoggerAdapter) Println(args ...interface{}) {
|
||||
l.logInfo.Print(args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a formatted error message and panics
|
||||
func (l *LegacyLoggerAdapter) Fatalf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
// WithField returns the same logger (no structured logging support in legacy adapter)
|
||||
func (l *LegacyLoggerAdapter) WithField(key string, value interface{}) Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
// WithFields returns the same logger (no structured logging support in legacy adapter)
|
||||
func (l *LegacyLoggerAdapter) WithFields(fields map[string]interface{}) Logger {
|
||||
return l
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Factory creates and manages logger instances with singleton support
|
||||
// for common logger types to reduce memory allocation.
|
||||
type Factory struct {
|
||||
mu sync.RWMutex
|
||||
defaultLogger Logger
|
||||
noOpLogger Logger
|
||||
loggers map[string]Logger
|
||||
defaultLogLevel string
|
||||
}
|
||||
|
||||
var (
|
||||
// globalFactory is the singleton factory instance
|
||||
globalFactory *Factory
|
||||
// factoryOnce ensures the factory is created only once
|
||||
factoryOnce sync.Once
|
||||
)
|
||||
|
||||
// GetFactory returns the global logger factory instance
|
||||
func GetFactory() *Factory {
|
||||
factoryOnce.Do(func() {
|
||||
globalFactory = &Factory{
|
||||
loggers: make(map[string]Logger),
|
||||
defaultLogLevel: "info",
|
||||
}
|
||||
})
|
||||
return globalFactory
|
||||
}
|
||||
|
||||
// SetDefaultLogLevel sets the default log level for new loggers
|
||||
func (f *Factory) SetDefaultLogLevel(level string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.defaultLogLevel = level
|
||||
}
|
||||
|
||||
// GetLogger returns a logger for the given name, creating one if it doesn't exist
|
||||
func (f *Factory) GetLogger(name string) Logger {
|
||||
f.mu.RLock()
|
||||
if logger, exists := f.loggers[name]; exists {
|
||||
f.mu.RUnlock()
|
||||
return logger
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
// Create new logger
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
// Double check after acquiring write lock
|
||||
if logger, exists := f.loggers[name]; exists {
|
||||
return logger
|
||||
}
|
||||
|
||||
logger := f.createLogger(name)
|
||||
f.loggers[name] = logger
|
||||
return logger
|
||||
}
|
||||
|
||||
// createLogger creates a new logger instance
|
||||
func (f *Factory) createLogger(name string) Logger {
|
||||
if name == "noop" || name == "no-op" || name == "discard" {
|
||||
return GetNoOpLogger()
|
||||
}
|
||||
|
||||
// Create logger with appropriate outputs based on environment
|
||||
var errorOut, infoOut, debugOut io.Writer
|
||||
|
||||
if os.Getenv("OIDC_LOG_TO_FILE") == "true" {
|
||||
// Log to files if configured
|
||||
errorOut = getOrCreateLogFile("error.log")
|
||||
infoOut = getOrCreateLogFile("info.log")
|
||||
debugOut = getOrCreateLogFile("debug.log")
|
||||
} else {
|
||||
// Default to stdout/stderr
|
||||
errorOut = os.Stderr
|
||||
infoOut = os.Stdout
|
||||
debugOut = os.Stdout
|
||||
}
|
||||
|
||||
return NewStandardLogger(f.defaultLogLevel, errorOut, infoOut, debugOut)
|
||||
}
|
||||
|
||||
// GetDefaultLogger returns the default logger instance
|
||||
func (f *Factory) GetDefaultLogger() Logger {
|
||||
f.mu.RLock()
|
||||
if f.defaultLogger != nil {
|
||||
f.mu.RUnlock()
|
||||
return f.defaultLogger
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.defaultLogger == nil {
|
||||
f.defaultLogger = f.createLogger("default")
|
||||
}
|
||||
|
||||
return f.defaultLogger
|
||||
}
|
||||
|
||||
// GetNoOpLogger returns the singleton no-op logger
|
||||
func (f *Factory) GetNoOpLogger() Logger {
|
||||
f.mu.RLock()
|
||||
if f.noOpLogger != nil {
|
||||
f.mu.RUnlock()
|
||||
return f.noOpLogger
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.noOpLogger == nil {
|
||||
f.noOpLogger = GetNoOpLogger()
|
||||
}
|
||||
|
||||
return f.noOpLogger
|
||||
}
|
||||
|
||||
// Clear removes all cached loggers (useful for testing)
|
||||
func (f *Factory) Clear() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
f.loggers = make(map[string]Logger)
|
||||
f.defaultLogger = nil
|
||||
// Don't clear noOpLogger as it's a singleton
|
||||
}
|
||||
|
||||
// getOrCreateLogFile returns a file writer for the given log file
|
||||
func getOrCreateLogFile(filename string) io.Writer {
|
||||
logDir := os.Getenv("OIDC_LOG_DIR")
|
||||
if logDir == "" {
|
||||
logDir = "/var/log/traefik-oidc"
|
||||
}
|
||||
|
||||
// Ensure log directory exists
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
// Fall back to stderr if we can't create the directory
|
||||
return os.Stderr
|
||||
}
|
||||
|
||||
filepath := logDir + "/" + filename
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
// Fall back to stderr if we can't open the file
|
||||
return os.Stderr
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
|
||||
// Global convenience functions
|
||||
|
||||
// New creates a new logger with the specified level
|
||||
func New(level string) Logger {
|
||||
return GetFactory().GetLogger(level)
|
||||
}
|
||||
|
||||
// Default returns the default logger
|
||||
func Default() Logger {
|
||||
return GetFactory().GetDefaultLogger()
|
||||
}
|
||||
|
||||
// NoOp returns a no-op logger
|
||||
func NoOp() Logger {
|
||||
return GetFactory().GetNoOpLogger()
|
||||
}
|
||||
|
||||
// WithLevel creates a new logger with the specified level
|
||||
func WithLevel(level string) Logger {
|
||||
return NewStandardLogger(level, os.Stderr, os.Stdout, os.Stdout)
|
||||
}
|
||||
@@ -1,312 +0,0 @@
|
||||
// Package logger provides a unified logging interface for the entire application.
|
||||
// It consolidates all the duplicate logger interfaces into a single, comprehensive
|
||||
// interface that supports different log levels and structured logging.
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Logger is the unified interface for all logging operations in the application.
|
||||
// It combines all the methods from the various logger interfaces that were
|
||||
// previously scattered across different packages.
|
||||
type Logger interface {
|
||||
// Basic logging methods
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
|
||||
// Additional methods for compatibility with existing code
|
||||
Printf(format string, args ...interface{})
|
||||
Println(args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
|
||||
// Structured logging support
|
||||
WithField(key string, value interface{}) Logger
|
||||
WithFields(fields map[string]interface{}) Logger
|
||||
}
|
||||
|
||||
// StandardLogger implements the Logger interface using Go's standard log package.
|
||||
// It provides thread-safe logging with different output streams for different log levels.
|
||||
type StandardLogger struct {
|
||||
mu sync.RWMutex
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
logDebug *log.Logger
|
||||
fields map[string]interface{}
|
||||
level LogLevel
|
||||
}
|
||||
|
||||
// LogLevel represents the logging level
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
// LogLevelDebug enables all log messages
|
||||
LogLevelDebug LogLevel = iota
|
||||
// LogLevelInfo enables info and error messages
|
||||
LogLevelInfo
|
||||
// LogLevelError enables only error messages
|
||||
LogLevelError
|
||||
// LogLevelNone disables all logging
|
||||
LogLevelNone
|
||||
)
|
||||
|
||||
// ParseLogLevel converts a string log level to LogLevel
|
||||
func ParseLogLevel(level string) LogLevel {
|
||||
switch level {
|
||||
case "debug", "DEBUG":
|
||||
return LogLevelDebug
|
||||
case "info", "INFO":
|
||||
return LogLevelInfo
|
||||
case "error", "ERROR":
|
||||
return LogLevelError
|
||||
case "none", "NONE":
|
||||
return LogLevelNone
|
||||
default:
|
||||
return LogLevelInfo
|
||||
}
|
||||
}
|
||||
|
||||
// NewStandardLogger creates a new StandardLogger with the specified log level
|
||||
func NewStandardLogger(level string, errorOutput, infoOutput, debugOutput io.Writer) *StandardLogger {
|
||||
logLevel := ParseLogLevel(level)
|
||||
|
||||
if errorOutput == nil {
|
||||
errorOutput = io.Discard
|
||||
}
|
||||
if infoOutput == nil {
|
||||
infoOutput = io.Discard
|
||||
}
|
||||
if debugOutput == nil {
|
||||
debugOutput = io.Discard
|
||||
}
|
||||
|
||||
return &StandardLogger{
|
||||
logError: log.New(errorOutput, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
logInfo: log.New(infoOutput, "INFO: ", log.Ldate|log.Ltime),
|
||||
logDebug: log.New(debugOutput, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile),
|
||||
fields: make(map[string]interface{}),
|
||||
level: logLevel,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *StandardLogger) Debug(msg string) {
|
||||
if l.level <= LogLevelDebug {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logDebug.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf logs a formatted debug message
|
||||
func (l *StandardLogger) Debugf(format string, args ...interface{}) {
|
||||
if l.level <= LogLevelDebug {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logDebug.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *StandardLogger) Info(msg string) {
|
||||
if l.level <= LogLevelInfo {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logInfo.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Infof logs a formatted info message
|
||||
func (l *StandardLogger) Infof(format string, args ...interface{}) {
|
||||
if l.level <= LogLevelInfo {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logInfo.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *StandardLogger) Error(msg string) {
|
||||
if l.level <= LogLevelError {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logError.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Errorf logs a formatted error message
|
||||
func (l *StandardLogger) Errorf(format string, args ...interface{}) {
|
||||
if l.level <= LogLevelError {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if len(l.fields) > 0 {
|
||||
msg = l.formatWithFields(msg)
|
||||
}
|
||||
l.logError.Print(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Printf logs a formatted message at info level
|
||||
func (l *StandardLogger) Printf(format string, args ...interface{}) {
|
||||
l.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Println logs a message at info level
|
||||
func (l *StandardLogger) Println(args ...interface{}) {
|
||||
l.Info(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
// Fatalf logs a formatted error message and exits the program
|
||||
func (l *StandardLogger) Fatalf(format string, args ...interface{}) {
|
||||
l.Errorf(format, args...)
|
||||
panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
// WithField returns a new logger with an additional field
|
||||
func (l *StandardLogger) WithField(key string, value interface{}) Logger {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
newLogger := &StandardLogger{
|
||||
logError: l.logError,
|
||||
logInfo: l.logInfo,
|
||||
logDebug: l.logDebug,
|
||||
fields: make(map[string]interface{}, len(l.fields)+1),
|
||||
level: l.level,
|
||||
}
|
||||
|
||||
for k, v := range l.fields {
|
||||
newLogger.fields[k] = v
|
||||
}
|
||||
newLogger.fields[key] = value
|
||||
|
||||
return newLogger
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with additional fields
|
||||
func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
newLogger := &StandardLogger{
|
||||
logError: l.logError,
|
||||
logInfo: l.logInfo,
|
||||
logDebug: l.logDebug,
|
||||
fields: make(map[string]interface{}, len(l.fields)+len(fields)),
|
||||
level: l.level,
|
||||
}
|
||||
|
||||
for k, v := range l.fields {
|
||||
newLogger.fields[k] = v
|
||||
}
|
||||
for k, v := range fields {
|
||||
newLogger.fields[k] = v
|
||||
}
|
||||
|
||||
return newLogger
|
||||
}
|
||||
|
||||
// formatWithFields formats a message with structured fields
|
||||
func (l *StandardLogger) formatWithFields(msg string) string {
|
||||
if len(l.fields) == 0 {
|
||||
return msg
|
||||
}
|
||||
|
||||
fieldsStr := ""
|
||||
for k, v := range l.fields {
|
||||
if fieldsStr != "" {
|
||||
fieldsStr += " "
|
||||
}
|
||||
fieldsStr += fmt.Sprintf("%s=%v", k, v)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s [%s]", msg, fieldsStr)
|
||||
}
|
||||
|
||||
// NoOpLogger is a logger that discards all output.
|
||||
// It's useful for testing and for cases where logging should be disabled.
|
||||
type NoOpLogger struct{}
|
||||
|
||||
// Debug discards the message
|
||||
func (n *NoOpLogger) Debug(msg string) {}
|
||||
|
||||
// Debugf discards the formatted message
|
||||
func (n *NoOpLogger) Debugf(format string, args ...interface{}) {}
|
||||
|
||||
// Info discards the message
|
||||
func (n *NoOpLogger) Info(msg string) {}
|
||||
|
||||
// Infof discards the formatted message
|
||||
func (n *NoOpLogger) Infof(format string, args ...interface{}) {}
|
||||
|
||||
// Error discards the message
|
||||
func (n *NoOpLogger) Error(msg string) {}
|
||||
|
||||
// Errorf discards the formatted message
|
||||
func (n *NoOpLogger) Errorf(format string, args ...interface{}) {}
|
||||
|
||||
// Printf discards the formatted message
|
||||
func (n *NoOpLogger) Printf(format string, args ...interface{}) {}
|
||||
|
||||
// Println discards the message
|
||||
func (n *NoOpLogger) Println(args ...interface{}) {}
|
||||
|
||||
// Fatalf discards the message and does not exit
|
||||
func (n *NoOpLogger) Fatalf(format string, args ...interface{}) {}
|
||||
|
||||
// WithField returns the same NoOpLogger
|
||||
func (n *NoOpLogger) WithField(key string, value interface{}) Logger {
|
||||
return n
|
||||
}
|
||||
|
||||
// WithFields returns the same NoOpLogger
|
||||
func (n *NoOpLogger) WithFields(fields map[string]interface{}) Logger {
|
||||
return n
|
||||
}
|
||||
|
||||
var (
|
||||
// singletonNoOpLogger is the global instance of the no-op logger
|
||||
singletonNoOpLogger *NoOpLogger
|
||||
// noOpLoggerOnce ensures the singleton is created only once
|
||||
noOpLoggerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetNoOpLogger returns the singleton no-op logger instance.
|
||||
// This reduces memory allocation by reusing the same no-op logger
|
||||
// instance across the entire application.
|
||||
func GetNoOpLogger() Logger {
|
||||
noOpLoggerOnce.Do(func() {
|
||||
singletonNoOpLogger = &NoOpLogger{}
|
||||
})
|
||||
return singletonNoOpLogger
|
||||
}
|
||||
|
||||
// DefaultLogger creates a default logger based on the provided configuration
|
||||
func DefaultLogger(level string) Logger {
|
||||
return NewStandardLogger(level, log.Writer(), log.Writer(), log.Writer())
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,122 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RequestContext holds request processing context
|
||||
type RequestContext struct {
|
||||
Writer http.ResponseWriter
|
||||
Request *http.Request
|
||||
RedirectURL string
|
||||
Scheme string
|
||||
Host string
|
||||
}
|
||||
|
||||
// RequestProcessor handles common request processing operations
|
||||
type RequestProcessor struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Logger interface for logging operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewRequestProcessor creates a new request processor
|
||||
func NewRequestProcessor(logger Logger) *RequestProcessor {
|
||||
return &RequestProcessor{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRequestContext creates a request context with scheme and host detection
|
||||
func (rp *RequestProcessor) BuildRequestContext(rw http.ResponseWriter, req *http.Request, redirectPath string) *RequestContext {
|
||||
scheme := rp.determineScheme(req)
|
||||
host := rp.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, redirectPath)
|
||||
|
||||
return &RequestContext{
|
||||
Writer: rw,
|
||||
Request: req,
|
||||
RedirectURL: redirectURL,
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthCheckRequest checks if request is a health check
|
||||
func (rp *RequestProcessor) IsHealthCheckRequest(req *http.Request) bool {
|
||||
return strings.HasPrefix(req.URL.Path, "/health")
|
||||
}
|
||||
|
||||
// IsEventStreamRequest checks if request expects event stream
|
||||
func (rp *RequestProcessor) IsEventStreamRequest(req *http.Request) bool {
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
return strings.Contains(acceptHeader, "text/event-stream")
|
||||
}
|
||||
|
||||
// IsAjaxRequest determines if this is an AJAX request
|
||||
func (rp *RequestProcessor) IsAjaxRequest(req *http.Request) bool {
|
||||
xhr := req.Header.Get("X-Requested-With")
|
||||
contentType := req.Header.Get("Content-Type")
|
||||
accept := req.Header.Get("Accept")
|
||||
|
||||
return xhr == "XMLHttpRequest" ||
|
||||
strings.Contains(contentType, "application/json") ||
|
||||
strings.Contains(accept, "application/json")
|
||||
}
|
||||
|
||||
// WaitForInitialization waits for OIDC provider initialization with timeout
|
||||
func (rp *RequestProcessor) WaitForInitialization(req *http.Request, initComplete <-chan struct{}) error {
|
||||
select {
|
||||
case <-initComplete:
|
||||
return nil
|
||||
case <-req.Context().Done():
|
||||
rp.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
return fmt.Errorf("request canceled")
|
||||
case <-time.After(30 * time.Second):
|
||||
rp.logger.Error("Timeout waiting for OIDC initialization")
|
||||
return fmt.Errorf("timeout waiting for OIDC provider initialization")
|
||||
}
|
||||
}
|
||||
|
||||
// determineScheme determines the URL scheme for building redirect URLs
|
||||
func (rp *RequestProcessor) determineScheme(req *http.Request) string {
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// determineHost determines the host for building redirect URLs
|
||||
func (rp *RequestProcessor) determineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// buildFullURL constructs a complete URL from scheme, host, and path components
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
@@ -1,655 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockLogger implements the Logger interface for testing
|
||||
type MockLogger struct {
|
||||
DebugCalls []string
|
||||
DebugfCalls []string
|
||||
ErrorCalls []string
|
||||
ErrorfCalls []string
|
||||
InfoCalls []string
|
||||
InfofCalls []string
|
||||
}
|
||||
|
||||
func (m *MockLogger) Debug(msg string) {
|
||||
m.DebugCalls = append(m.DebugCalls, msg)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Debugf(format string, args ...interface{}) {
|
||||
m.DebugfCalls = append(m.DebugfCalls, format)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Error(msg string) {
|
||||
m.ErrorCalls = append(m.ErrorCalls, msg)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Errorf(format string, args ...interface{}) {
|
||||
m.ErrorfCalls = append(m.ErrorfCalls, format)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Info(msg string) {
|
||||
m.InfoCalls = append(m.InfoCalls, msg)
|
||||
}
|
||||
|
||||
func (m *MockLogger) Infof(format string, args ...interface{}) {
|
||||
m.InfofCalls = append(m.InfofCalls, format)
|
||||
}
|
||||
|
||||
// TestNewRequestProcessor tests the constructor
|
||||
func TestNewRequestProcessor(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
if processor == nil {
|
||||
t.Error("Expected NewRequestProcessor to return non-nil processor")
|
||||
return
|
||||
}
|
||||
|
||||
if processor.logger != logger {
|
||||
t.Error("Expected processor to use provided logger")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildRequestContext tests request context building
|
||||
func TestBuildRequestContext(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, http.ResponseWriter)
|
||||
redirectPath string
|
||||
expectedURL string
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "Basic HTTP request",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
return req, rw
|
||||
},
|
||||
redirectPath: "/callback",
|
||||
expectedURL: "http://example.com/callback",
|
||||
expectedHost: "example.com",
|
||||
},
|
||||
{
|
||||
name: "HTTPS request with TLS",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "https://secure.com/test", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS
|
||||
rw := httptest.NewRecorder()
|
||||
return req, rw
|
||||
},
|
||||
redirectPath: "/auth",
|
||||
expectedURL: "https://secure.com/auth",
|
||||
expectedHost: "secure.com",
|
||||
},
|
||||
{
|
||||
name: "Request with X-Forwarded-Proto header",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
rw := httptest.NewRecorder()
|
||||
return req, rw
|
||||
},
|
||||
redirectPath: "/callback",
|
||||
expectedURL: "https://internal.com/callback",
|
||||
expectedHost: "internal.com",
|
||||
},
|
||||
{
|
||||
name: "Request with X-Forwarded-Host header",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
|
||||
req.Header.Set("X-Forwarded-Host", "public.com")
|
||||
rw := httptest.NewRecorder()
|
||||
return req, rw
|
||||
},
|
||||
redirectPath: "/callback",
|
||||
expectedURL: "http://public.com/callback",
|
||||
expectedHost: "public.com",
|
||||
},
|
||||
{
|
||||
name: "Request with both forwarded headers",
|
||||
setupRequest: func() (*http.Request, http.ResponseWriter) {
|
||||
req := httptest.NewRequest("GET", "http://internal.com/test", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "public.com")
|
||||
rw := httptest.NewRecorder()
|
||||
return req, rw
|
||||
},
|
||||
redirectPath: "/auth",
|
||||
expectedURL: "https://public.com/auth",
|
||||
expectedHost: "public.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, rw := tt.setupRequest()
|
||||
ctx := processor.BuildRequestContext(rw, req, tt.redirectPath)
|
||||
|
||||
if ctx == nil {
|
||||
t.Error("Expected BuildRequestContext to return non-nil context")
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Writer != rw {
|
||||
t.Error("Expected context writer to match provided writer")
|
||||
}
|
||||
|
||||
if ctx.Request != req {
|
||||
t.Error("Expected context request to match provided request")
|
||||
}
|
||||
|
||||
if ctx.RedirectURL != tt.expectedURL {
|
||||
t.Errorf("Expected redirect URL '%s', got '%s'", tt.expectedURL, ctx.RedirectURL)
|
||||
}
|
||||
|
||||
if ctx.Host != tt.expectedHost {
|
||||
t.Errorf("Expected host '%s', got '%s'", tt.expectedHost, ctx.Host)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsHealthCheckRequest tests health check detection
|
||||
func TestIsHealthCheckRequest(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Health check path",
|
||||
path: "/health",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Health check subpath",
|
||||
path: "/health/status",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Health check with query params",
|
||||
path: "/health?check=db",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Not a health check",
|
||||
path: "/api/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Health-related path (matches prefix)",
|
||||
path: "/healthiness",
|
||||
expected: true, // HasPrefix behavior - this actually matches
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
path: "/",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+tt.path, nil)
|
||||
result := processor.IsHealthCheckRequest(req)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsHealthCheckRequest to return %v for path '%s', got %v", tt.expected, tt.path, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsEventStreamRequest tests event stream detection
|
||||
func TestIsEventStreamRequest(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
acceptHeader string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Event stream accept header",
|
||||
acceptHeader: "text/event-stream",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Event stream with other types",
|
||||
acceptHeader: "text/html, text/event-stream, application/json",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON accept header",
|
||||
acceptHeader: "application/json",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HTML accept header",
|
||||
acceptHeader: "text/html,application/xhtml+xml",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty accept header",
|
||||
acceptHeader: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Similar but not event stream",
|
||||
acceptHeader: "text/event-source",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
if tt.acceptHeader != "" {
|
||||
req.Header.Set("Accept", tt.acceptHeader)
|
||||
}
|
||||
|
||||
result := processor.IsEventStreamRequest(req)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsEventStreamRequest to return %v for accept header '%s', got %v", tt.expected, tt.acceptHeader, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAjaxRequest tests AJAX request detection
|
||||
func TestIsAjaxRequest(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupHeader func(*http.Request)
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "XMLHttpRequest header",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON content type",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON content type with charset",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON accept header",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JSON accept with other types",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Accept", "text/html, application/json, application/xml")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple AJAX indicators",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Regular HTML request",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Accept", "text/html,application/xhtml+xml")
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Form submission",
|
||||
setupHeader: func(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No special headers",
|
||||
setupHeader: func(req *http.Request) {},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://example.com/api", nil)
|
||||
tt.setupHeader(req)
|
||||
|
||||
result := processor.IsAjaxRequest(req)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsAjaxRequest to return %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWaitForInitialization tests initialization waiting
|
||||
func TestWaitForInitialization(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
t.Run("Initialization completes successfully", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
initComplete := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(initComplete)
|
||||
}()
|
||||
|
||||
err := processor.WaitForInitialization(req, initComplete)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error when initialization completes, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Request context canceled", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req = req.WithContext(ctx)
|
||||
initComplete := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := processor.WaitForInitialization(req, initComplete)
|
||||
if err == nil {
|
||||
t.Error("Expected error when request context is canceled")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "request canceled") {
|
||||
t.Errorf("Expected 'request canceled' error, got: %v", err)
|
||||
}
|
||||
|
||||
if len(logger.DebugCalls) == 0 {
|
||||
t.Error("Expected debug log when request is canceled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Initialization timeout", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping timeout test in short mode")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
initComplete := make(chan struct{}) // Never closes
|
||||
|
||||
// Note: This test takes 30 seconds due to hardcoded timeout in implementation
|
||||
start := time.Now()
|
||||
err := processor.WaitForInitialization(req, initComplete)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "timeout") {
|
||||
t.Errorf("Expected timeout error, got: %v", err)
|
||||
}
|
||||
|
||||
// The timeout should be around 30 seconds, allow some variance
|
||||
if duration < 29*time.Second || duration > 31*time.Second {
|
||||
t.Errorf("Expected timeout after ~30 seconds, but got %v", duration)
|
||||
}
|
||||
|
||||
if len(logger.ErrorCalls) == 0 {
|
||||
t.Error("Expected error log when timeout occurs")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDetermineScheme tests scheme determination
|
||||
func TestDetermineScheme(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*http.Request)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Proto HTTPS",
|
||||
setup: func(req *http.Request) {
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
},
|
||||
expected: "https",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto HTTP",
|
||||
setup: func(req *http.Request) {
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
},
|
||||
expected: "http",
|
||||
},
|
||||
{
|
||||
name: "TLS connection without header",
|
||||
setup: func(req *http.Request) {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
},
|
||||
expected: "https",
|
||||
},
|
||||
{
|
||||
name: "No TLS, no header",
|
||||
setup: func(req *http.Request) {
|
||||
// No special setup
|
||||
},
|
||||
expected: "http",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto takes precedence over TLS",
|
||||
setup: func(req *http.Request) {
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
},
|
||||
expected: "http",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
tt.setup(req)
|
||||
|
||||
result := processor.determineScheme(req)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected scheme '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDetermineHost tests host determination
|
||||
func TestDetermineHost(t *testing.T) {
|
||||
logger := &MockLogger{}
|
||||
processor := NewRequestProcessor(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*http.Request)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Host header present",
|
||||
setup: func(req *http.Request) {
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
},
|
||||
expected: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "No X-Forwarded-Host, use req.Host",
|
||||
setup: func(req *http.Request) {
|
||||
// No special setup, will use req.Host
|
||||
},
|
||||
expected: "example.com",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Host, fallback to req.Host",
|
||||
setup: func(req *http.Request) {
|
||||
req.Header.Set("X-Forwarded-Host", "")
|
||||
},
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
tt.setup(req)
|
||||
|
||||
result := processor.determineHost(req)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected host '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFullURL tests URL building
|
||||
func TestBuildFullURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Basic URL construction",
|
||||
scheme: "https",
|
||||
host: "example.com",
|
||||
path: "/callback",
|
||||
expected: "https://example.com/callback",
|
||||
},
|
||||
{
|
||||
name: "Path without leading slash",
|
||||
scheme: "http",
|
||||
host: "test.com",
|
||||
path: "auth",
|
||||
expected: "http://test.com/auth",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTP URL in path",
|
||||
scheme: "https",
|
||||
host: "example.com",
|
||||
path: "http://other.com/callback",
|
||||
expected: "http://other.com/callback",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTPS URL in path",
|
||||
scheme: "http",
|
||||
host: "example.com",
|
||||
path: "https://secure.com/auth",
|
||||
expected: "https://secure.com/auth",
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
scheme: "https",
|
||||
host: "example.com:8080",
|
||||
path: "/",
|
||||
expected: "https://example.com:8080/",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
scheme: "https",
|
||||
host: "example.com",
|
||||
path: "",
|
||||
expected: "https://example.com/",
|
||||
},
|
||||
{
|
||||
name: "Path with query parameters",
|
||||
scheme: "https",
|
||||
host: "example.com",
|
||||
path: "/callback?state=abc123",
|
||||
expected: "https://example.com/callback?state=abc123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildFullURL(tt.scheme, tt.host, tt.path)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected URL '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestContext tests the RequestContext struct
|
||||
func TestRequestContext(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
ctx := &RequestContext{
|
||||
Writer: rw,
|
||||
Request: req,
|
||||
RedirectURL: "https://example.com/callback",
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
}
|
||||
|
||||
if ctx.Writer != rw {
|
||||
t.Error("Expected Writer to be set correctly")
|
||||
}
|
||||
|
||||
if ctx.Request != req {
|
||||
t.Error("Expected Request to be set correctly")
|
||||
}
|
||||
|
||||
if ctx.RedirectURL != "https://example.com/callback" {
|
||||
t.Error("Expected RedirectURL to be set correctly")
|
||||
}
|
||||
|
||||
if ctx.Scheme != "https" {
|
||||
t.Error("Expected Scheme to be set correctly")
|
||||
}
|
||||
|
||||
if ctx.Host != "example.com" {
|
||||
t.Error("Expected Host to be set correctly")
|
||||
}
|
||||
}
|
||||
@@ -1,309 +0,0 @@
|
||||
// Package patterns provides cached compiled regex patterns for performance optimization
|
||||
package patterns
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RegexCache manages compiled regex patterns with thread-safe access
|
||||
type RegexCache struct {
|
||||
patterns map[string]*regexp.Regexp
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRegexCache creates a new regex cache instance
|
||||
func NewRegexCache() *RegexCache {
|
||||
return &RegexCache{
|
||||
patterns: make(map[string]*regexp.Regexp),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a compiled regex pattern, compiling and caching it if not present
|
||||
func (c *RegexCache) Get(pattern string) (*regexp.Regexp, error) {
|
||||
// First try read lock for existing pattern
|
||||
c.mu.RLock()
|
||||
if regex, exists := c.patterns[pattern]; exists {
|
||||
c.mu.RUnlock()
|
||||
return regex, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Pattern not found, acquire write lock to compile and cache
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Double-check in case another goroutine compiled it while we waited
|
||||
if regex, exists := c.patterns[pattern]; exists {
|
||||
return regex, nil
|
||||
}
|
||||
|
||||
// Compile the pattern
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the compiled pattern
|
||||
c.patterns[pattern] = regex
|
||||
return regex, nil
|
||||
}
|
||||
|
||||
// MustGet is like Get but panics if the pattern cannot be compiled
|
||||
func (c *RegexCache) MustGet(pattern string) *regexp.Regexp {
|
||||
regex, err := c.Get(pattern)
|
||||
if err != nil {
|
||||
panic("regex compilation failed for pattern '" + pattern + "': " + err.Error())
|
||||
}
|
||||
return regex
|
||||
}
|
||||
|
||||
// Precompile compiles and caches multiple patterns at once
|
||||
func (c *RegexCache) Precompile(patterns []string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if _, exists := c.patterns[pattern]; !exists {
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.patterns[pattern] = regex
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the number of cached patterns
|
||||
func (c *RegexCache) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.patterns)
|
||||
}
|
||||
|
||||
// Clear removes all cached patterns
|
||||
func (c *RegexCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.patterns = make(map[string]*regexp.Regexp)
|
||||
}
|
||||
|
||||
// Global regex cache instance
|
||||
var globalCache = NewRegexCache()
|
||||
|
||||
// Common regex patterns used throughout the OIDC implementation
|
||||
const (
|
||||
// Email validation pattern (RFC 5322 compliant)
|
||||
EmailPattern = `^[a-zA-Z0-9.!#$%&'*+/=?^_` + "`" + `{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`
|
||||
|
||||
// Domain validation pattern
|
||||
DomainPattern = `^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`
|
||||
|
||||
// URL validation pattern (http/https)
|
||||
URLPattern = `^https?://[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*(/.*)?$`
|
||||
|
||||
// JWT token pattern (three base64url parts separated by dots)
|
||||
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
|
||||
|
||||
// Bearer token pattern (Authorization header)
|
||||
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
|
||||
|
||||
// Client ID pattern (alphanumeric with common separators)
|
||||
ClientIDPattern = `^[a-zA-Z0-9._-]+$`
|
||||
|
||||
// Scope pattern (space-separated alphanumeric with underscores)
|
||||
ScopePattern = `^[a-zA-Z0-9_]+(\s+[a-zA-Z0-9_]+)*$`
|
||||
|
||||
// Session ID pattern (hexadecimal)
|
||||
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
|
||||
|
||||
// CSRF token pattern (base64url)
|
||||
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
|
||||
|
||||
// Nonce pattern (base64url)
|
||||
NoncePattern = `^[A-Za-z0-9_-]+$`
|
||||
|
||||
// Code verifier pattern for PKCE (base64url, 43-128 chars)
|
||||
CodeVerifierPattern = `^[A-Za-z0-9_-]{43,128}$`
|
||||
|
||||
// Authorization code pattern (base64url)
|
||||
AuthCodePattern = `^[A-Za-z0-9._~+/-]+=*$`
|
||||
|
||||
// Redirect URI validation (must be absolute HTTP/HTTPS URL)
|
||||
RedirectURIPattern = `^https?://[^\s/$.?#].[^\s]*$`
|
||||
|
||||
// User-Agent pattern for bot detection
|
||||
BotUserAgentPattern = `(?i)(bot|crawler|spider|scraper|curl|wget|python|java|go-http)`
|
||||
|
||||
// IP address pattern (IPv4)
|
||||
IPv4Pattern = `^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`
|
||||
|
||||
// Tenant ID pattern (UUID format for Azure, etc.)
|
||||
TenantIDPattern = `^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`
|
||||
)
|
||||
|
||||
// Precompiled common patterns for immediate use
|
||||
var (
|
||||
EmailRegex *regexp.Regexp
|
||||
DomainRegex *regexp.Regexp
|
||||
URLRegex *regexp.Regexp
|
||||
JWTRegex *regexp.Regexp
|
||||
BearerTokenRegex *regexp.Regexp
|
||||
ClientIDRegex *regexp.Regexp
|
||||
ScopeRegex *regexp.Regexp
|
||||
SessionIDRegex *regexp.Regexp
|
||||
CSRFTokenRegex *regexp.Regexp
|
||||
NonceRegex *regexp.Regexp
|
||||
CodeVerifierRegex *regexp.Regexp
|
||||
AuthCodeRegex *regexp.Regexp
|
||||
RedirectURIRegex *regexp.Regexp
|
||||
BotUserAgentRegex *regexp.Regexp
|
||||
IPv4Regex *regexp.Regexp
|
||||
TenantIDRegex *regexp.Regexp
|
||||
)
|
||||
|
||||
// Initialize precompiled patterns
|
||||
func init() {
|
||||
commonPatterns := []string{
|
||||
EmailPattern,
|
||||
DomainPattern,
|
||||
URLPattern,
|
||||
JWTPattern,
|
||||
BearerTokenPattern,
|
||||
ClientIDPattern,
|
||||
ScopePattern,
|
||||
SessionIDPattern,
|
||||
CSRFTokenPattern,
|
||||
NoncePattern,
|
||||
CodeVerifierPattern,
|
||||
AuthCodePattern,
|
||||
RedirectURIPattern,
|
||||
BotUserAgentPattern,
|
||||
IPv4Pattern,
|
||||
TenantIDPattern,
|
||||
}
|
||||
|
||||
if err := globalCache.Precompile(commonPatterns); err != nil {
|
||||
panic("Failed to precompile common regex patterns: " + err.Error())
|
||||
}
|
||||
|
||||
// Assign precompiled patterns to global variables for easy access
|
||||
EmailRegex = globalCache.MustGet(EmailPattern)
|
||||
DomainRegex = globalCache.MustGet(DomainPattern)
|
||||
URLRegex = globalCache.MustGet(URLPattern)
|
||||
JWTRegex = globalCache.MustGet(JWTPattern)
|
||||
BearerTokenRegex = globalCache.MustGet(BearerTokenPattern)
|
||||
ClientIDRegex = globalCache.MustGet(ClientIDPattern)
|
||||
ScopeRegex = globalCache.MustGet(ScopePattern)
|
||||
SessionIDRegex = globalCache.MustGet(SessionIDPattern)
|
||||
CSRFTokenRegex = globalCache.MustGet(CSRFTokenPattern)
|
||||
NonceRegex = globalCache.MustGet(NoncePattern)
|
||||
CodeVerifierRegex = globalCache.MustGet(CodeVerifierPattern)
|
||||
AuthCodeRegex = globalCache.MustGet(AuthCodePattern)
|
||||
RedirectURIRegex = globalCache.MustGet(RedirectURIPattern)
|
||||
BotUserAgentRegex = globalCache.MustGet(BotUserAgentPattern)
|
||||
IPv4Regex = globalCache.MustGet(IPv4Pattern)
|
||||
TenantIDRegex = globalCache.MustGet(TenantIDPattern)
|
||||
}
|
||||
|
||||
// Global helper functions for common validations
|
||||
|
||||
// ValidateEmail checks if an email address is valid
|
||||
func ValidateEmail(email string) bool {
|
||||
return EmailRegex.MatchString(email)
|
||||
}
|
||||
|
||||
// ValidateDomain checks if a domain name is valid
|
||||
func ValidateDomain(domain string) bool {
|
||||
return DomainRegex.MatchString(domain)
|
||||
}
|
||||
|
||||
// ValidateURL checks if a URL is valid (http/https)
|
||||
func ValidateURL(url string) bool {
|
||||
return URLRegex.MatchString(url)
|
||||
}
|
||||
|
||||
// ValidateJWT checks if a token has valid JWT format
|
||||
func ValidateJWT(token string) bool {
|
||||
return JWTRegex.MatchString(token)
|
||||
}
|
||||
|
||||
// ExtractBearerToken extracts the token from a Bearer authorization header
|
||||
func ExtractBearerToken(authHeader string) (string, bool) {
|
||||
matches := BearerTokenRegex.FindStringSubmatch(authHeader)
|
||||
if len(matches) == 2 {
|
||||
return matches[1], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ValidateClientID checks if a client ID has valid format
|
||||
func ValidateClientID(clientID string) bool {
|
||||
return ClientIDRegex.MatchString(clientID)
|
||||
}
|
||||
|
||||
// ValidateScopes checks if scopes string has valid format
|
||||
func ValidateScopes(scopes string) bool {
|
||||
return ScopeRegex.MatchString(scopes)
|
||||
}
|
||||
|
||||
// ValidateSessionID checks if a session ID has valid format
|
||||
func ValidateSessionID(sessionID string) bool {
|
||||
return SessionIDRegex.MatchString(sessionID)
|
||||
}
|
||||
|
||||
// ValidateCSRFToken checks if a CSRF token has valid format
|
||||
func ValidateCSRFToken(token string) bool {
|
||||
return CSRFTokenRegex.MatchString(token)
|
||||
}
|
||||
|
||||
// ValidateNonce checks if a nonce has valid format
|
||||
func ValidateNonce(nonce string) bool {
|
||||
return NonceRegex.MatchString(nonce)
|
||||
}
|
||||
|
||||
// ValidateCodeVerifier checks if a PKCE code verifier has valid format
|
||||
func ValidateCodeVerifier(verifier string) bool {
|
||||
return CodeVerifierRegex.MatchString(verifier)
|
||||
}
|
||||
|
||||
// ValidateAuthCode checks if an authorization code has valid format
|
||||
func ValidateAuthCode(code string) bool {
|
||||
return AuthCodeRegex.MatchString(code)
|
||||
}
|
||||
|
||||
// ValidateRedirectURI checks if a redirect URI is valid
|
||||
func ValidateRedirectURI(uri string) bool {
|
||||
return RedirectURIRegex.MatchString(uri)
|
||||
}
|
||||
|
||||
// IsBotUserAgent checks if a User-Agent suggests an automated client
|
||||
func IsBotUserAgent(userAgent string) bool {
|
||||
return BotUserAgentRegex.MatchString(userAgent)
|
||||
}
|
||||
|
||||
// ValidateIPv4 checks if an IP address is valid IPv4
|
||||
func ValidateIPv4(ip string) bool {
|
||||
return IPv4Regex.MatchString(ip)
|
||||
}
|
||||
|
||||
// ValidateTenantID checks if a tenant ID has valid UUID format
|
||||
func ValidateTenantID(tenantID string) bool {
|
||||
return TenantIDRegex.MatchString(tenantID)
|
||||
}
|
||||
|
||||
// GetGlobalCache returns the global regex cache instance
|
||||
func GetGlobalCache() *RegexCache {
|
||||
return globalCache
|
||||
}
|
||||
|
||||
// CompilePattern compiles a pattern using the global cache
|
||||
func CompilePattern(pattern string) (*regexp.Regexp, error) {
|
||||
return globalCache.Get(pattern)
|
||||
}
|
||||
|
||||
// MustCompilePattern compiles a pattern using the global cache, panicking on error
|
||||
func MustCompilePattern(pattern string) *regexp.Regexp {
|
||||
return globalCache.MustGet(pattern)
|
||||
}
|
||||
@@ -1,484 +0,0 @@
|
||||
package patterns
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegexCache_Get(t *testing.T) {
|
||||
cache := NewRegexCache()
|
||||
|
||||
pattern := `^test\d+$`
|
||||
|
||||
// First call should compile and cache
|
||||
regex1, err := cache.Get(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get regex: %v", err)
|
||||
}
|
||||
|
||||
// Second call should return cached version
|
||||
regex2, err := cache.Get(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get cached regex: %v", err)
|
||||
}
|
||||
|
||||
// Should be the same instance
|
||||
if regex1 != regex2 {
|
||||
t.Error("Expected same regex instance from cache")
|
||||
}
|
||||
|
||||
// Test the regex works
|
||||
if !regex1.MatchString("test123") {
|
||||
t.Error("Regex should match 'test123'")
|
||||
}
|
||||
|
||||
if regex1.MatchString("test") {
|
||||
t.Error("Regex should not match 'test'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCache_ConcurrentAccess(t *testing.T) {
|
||||
cache := NewRegexCache()
|
||||
pattern := `^concurrent\d+$`
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*regexp.Regexp, 10)
|
||||
errors := make([]error, 10)
|
||||
|
||||
// Launch multiple goroutines to access the same pattern
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
regex, err := cache.Get(pattern)
|
||||
results[index] = regex
|
||||
errors[index] = err
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check all succeeded
|
||||
for i, err := range errors {
|
||||
if err != nil {
|
||||
t.Fatalf("Goroutine %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// All should return the same instance
|
||||
first := results[0]
|
||||
for i, regex := range results[1:] {
|
||||
if regex != first {
|
||||
t.Errorf("Goroutine %d got different regex instance", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCache_InvalidPattern(t *testing.T) {
|
||||
cache := NewRegexCache()
|
||||
|
||||
_, err := cache.Get(`[invalid`)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid regex pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCache_Precompile(t *testing.T) {
|
||||
cache := NewRegexCache()
|
||||
|
||||
patterns := []string{
|
||||
`^test1$`,
|
||||
`^test2$`,
|
||||
`^test3$`,
|
||||
}
|
||||
|
||||
err := cache.Precompile(patterns)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to precompile patterns: %v", err)
|
||||
}
|
||||
|
||||
if cache.Size() != 3 {
|
||||
t.Errorf("Expected cache size 3, got %d", cache.Size())
|
||||
}
|
||||
|
||||
// Should be able to get precompiled patterns without error
|
||||
for _, pattern := range patterns {
|
||||
_, err := cache.Get(pattern)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get precompiled pattern %s: %v", pattern, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationFunctions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
function func(string) bool
|
||||
valid []string
|
||||
invalid []string
|
||||
}{
|
||||
{
|
||||
name: "ValidateEmail",
|
||||
function: ValidateEmail,
|
||||
valid: []string{"test@example.com", "user.name@domain.org", "admin+tag@company.co.uk"},
|
||||
invalid: []string{"invalid-email", "@domain.com", "user@", ""},
|
||||
},
|
||||
{
|
||||
name: "ValidateDomain",
|
||||
function: ValidateDomain,
|
||||
valid: []string{"example.com", "sub.domain.org", "test.co.uk"},
|
||||
invalid: []string{"", "invalid..domain", ".example.com", "domain."},
|
||||
},
|
||||
{
|
||||
name: "ValidateJWT",
|
||||
function: ValidateJWT,
|
||||
valid: []string{"eyJ0.eyJ1.sig", "a.b.c"},
|
||||
invalid: []string{"invalid", "a.b", "a.b.c.d", ""},
|
||||
},
|
||||
{
|
||||
name: "ValidateClientID",
|
||||
function: ValidateClientID,
|
||||
valid: []string{"client123", "my-client_id", "123.456"},
|
||||
invalid: []string{"", "client with spaces", "client@invalid"},
|
||||
},
|
||||
{
|
||||
name: "ValidateURL",
|
||||
function: ValidateURL,
|
||||
valid: []string{"https://example.com", "https://sub.domain.org/path", "http://localhost", "https://example.com/path?query=value", "http://192.168.1.1"},
|
||||
invalid: []string{"", "ftp://example.com", "not-a-url", "https://", "example.com", "http://localhost:8080"},
|
||||
},
|
||||
{
|
||||
name: "ValidateScopes",
|
||||
function: ValidateScopes,
|
||||
valid: []string{"openid", "openid profile", "read write admin", "user_info"},
|
||||
invalid: []string{"", "scope-with-dash", "scope@invalid", "scope with.dot", " "},
|
||||
},
|
||||
{
|
||||
name: "ValidateSessionID",
|
||||
function: ValidateSessionID,
|
||||
valid: []string{"a1b2c3d4e5f6789012345678901234567890abcdef", "ABCDEF1234567890abcdef1234567890", "0123456789abcdef0123456789abcdef"},
|
||||
invalid: []string{"", "too-short", "contains-invalid-chars!", "g123456789abcdef0123456789abcdef", "1234567890abcdef1234567890abcde"},
|
||||
},
|
||||
{
|
||||
name: "ValidateCSRFToken",
|
||||
function: ValidateCSRFToken,
|
||||
valid: []string{"abc123", "ABC_123-xyz", "token-value_123", "_valid-token_"},
|
||||
invalid: []string{"", "token with spaces", "token@invalid", "token.with.dots!", "token/with/slash"},
|
||||
},
|
||||
{
|
||||
name: "ValidateNonce",
|
||||
function: ValidateNonce,
|
||||
valid: []string{"abc123", "ABC_123-xyz", "nonce-value_123", "_valid-nonce_"},
|
||||
invalid: []string{"", "nonce with spaces", "nonce@invalid", "nonce.with.dots!", "nonce/with/slash"},
|
||||
},
|
||||
{
|
||||
name: "ValidateCodeVerifier",
|
||||
function: ValidateCodeVerifier,
|
||||
valid: []string{"dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"},
|
||||
invalid: []string{"", "too-short", "short", "verifier with spaces", "verifier@invalid", "a"},
|
||||
},
|
||||
{
|
||||
name: "ValidateAuthCode",
|
||||
function: ValidateAuthCode,
|
||||
valid: []string{"auth_code_123", "ABC.123-xyz/code+value=", "simple-code"},
|
||||
invalid: []string{"", "code with spaces", "code@invalid"},
|
||||
},
|
||||
{
|
||||
name: "ValidateRedirectURI",
|
||||
function: ValidateRedirectURI,
|
||||
valid: []string{"https://example.com/callback", "http://localhost:8080/auth", "https://app.example.org/oauth/callback", "http://127.0.0.1:3000"},
|
||||
invalid: []string{"", "ftp://example.com", "not-a-url", "example.com/callback", "https://"},
|
||||
},
|
||||
{
|
||||
name: "ValidateIPv4",
|
||||
function: ValidateIPv4,
|
||||
valid: []string{"192.168.1.1", "10.0.0.1", "127.0.0.1", "255.255.255.255", "0.0.0.0"},
|
||||
invalid: []string{"", "256.1.1.1", "192.168.1", "192.168.1.1.1", "not-an-ip"},
|
||||
},
|
||||
{
|
||||
name: "ValidateTenantID",
|
||||
function: ValidateTenantID,
|
||||
valid: []string{"12345678-1234-1234-1234-123456789abc", "ABCDEF12-3456-7890-ABCD-EF1234567890"},
|
||||
invalid: []string{"", "not-a-uuid", "12345678-1234-1234-1234", "12345678-1234-1234-1234-123456789abcd", "123456781234123412341234567890ab"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for _, valid := range tt.valid {
|
||||
if !tt.function(valid) {
|
||||
t.Errorf("%s should be valid: %s", tt.name, valid)
|
||||
}
|
||||
}
|
||||
|
||||
for _, invalid := range tt.invalid {
|
||||
if tt.function(invalid) {
|
||||
t.Errorf("%s should be invalid: %s", tt.name, invalid)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"Bearer abc123", "abc123", true},
|
||||
{"Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9", "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9", true},
|
||||
{"bearer token123", "", false}, // case sensitive
|
||||
{"Basic abc123", "", false},
|
||||
{"Bearer", "", false},
|
||||
{"", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
token, valid := ExtractBearerToken(tt.header)
|
||||
if valid != tt.valid {
|
||||
t.Errorf("ExtractBearerToken(%q) valid = %v, want %v", tt.header, valid, tt.valid)
|
||||
}
|
||||
if token != tt.expected {
|
||||
t.Errorf("ExtractBearerToken(%q) token = %q, want %q", tt.header, token, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRegexCache_Get(b *testing.B) {
|
||||
cache := NewRegexCache()
|
||||
pattern := `^benchmark\d+$`
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := cache.Get(pattern)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkRegexCache_Validation(b *testing.B) {
|
||||
email := "test@example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
ValidateEmail(email)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkRegex_DirectCompile(b *testing.B) {
|
||||
pattern := `^benchmark\d+$`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCache_Clear(t *testing.T) {
|
||||
cache := NewRegexCache()
|
||||
|
||||
// Add some patterns to the cache
|
||||
patterns := []string{`^test1$`, `^test2$`, `^test3$`}
|
||||
for _, pattern := range patterns {
|
||||
_, err := cache.Get(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add pattern %s: %v", pattern, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cache has patterns
|
||||
if cache.Size() != 3 {
|
||||
t.Errorf("Expected cache size 3, got %d", cache.Size())
|
||||
}
|
||||
|
||||
// Clear the cache
|
||||
cache.Clear()
|
||||
|
||||
// Verify cache is empty
|
||||
if cache.Size() != 0 {
|
||||
t.Errorf("Expected cache size 0 after clear, got %d", cache.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBotUserAgent(t *testing.T) {
|
||||
tests := []struct {
|
||||
userAgent string
|
||||
isBot bool
|
||||
}{
|
||||
{"Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)", true},
|
||||
{"Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)", true},
|
||||
{"facebookexternalhit/1.1 (+http://www.facebook.com/externalhit_uatext.php)", false},
|
||||
{"crawler-bot/1.0", true},
|
||||
{"spider-agent/2.0", true},
|
||||
{"curl/7.68.0", true},
|
||||
{"wget/1.20.3", true},
|
||||
{"python-requests/2.25.1", true},
|
||||
{"Go-http-client/1.1", true},
|
||||
{"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", false},
|
||||
{"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.userAgent, func(t *testing.T) {
|
||||
result := IsBotUserAgent(tt.userAgent)
|
||||
if result != tt.isBot {
|
||||
t.Errorf("IsBotUserAgent(%q) = %v, want %v", tt.userAgent, result, tt.isBot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGlobalCache(t *testing.T) {
|
||||
cache := GetGlobalCache()
|
||||
if cache == nil {
|
||||
t.Error("GetGlobalCache() should not return nil")
|
||||
}
|
||||
|
||||
// Should return the same instance
|
||||
cache2 := GetGlobalCache()
|
||||
if cache != cache2 {
|
||||
t.Error("GetGlobalCache() should return the same instance")
|
||||
}
|
||||
|
||||
// Should have precompiled patterns
|
||||
if cache.Size() == 0 {
|
||||
t.Error("Global cache should have precompiled patterns")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompilePattern(t *testing.T) {
|
||||
pattern := `^test_compile\d+$`
|
||||
|
||||
regex, err := CompilePattern(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("CompilePattern failed: %v", err)
|
||||
}
|
||||
|
||||
if !regex.MatchString("test_compile123") {
|
||||
t.Error("Compiled pattern should match 'test_compile123'")
|
||||
}
|
||||
|
||||
if regex.MatchString("test_compile") {
|
||||
t.Error("Compiled pattern should not match 'test_compile'")
|
||||
}
|
||||
|
||||
// Test invalid pattern
|
||||
_, err = CompilePattern(`[invalid`)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustCompilePattern(t *testing.T) {
|
||||
pattern := `^test_must_compile\d+$`
|
||||
|
||||
regex := MustCompilePattern(pattern)
|
||||
if regex == nil {
|
||||
t.Fatal("MustCompilePattern should not return nil")
|
||||
}
|
||||
|
||||
if !regex.MatchString("test_must_compile456") {
|
||||
t.Error("Compiled pattern should match 'test_must_compile456'")
|
||||
}
|
||||
|
||||
// Test that it panics with invalid pattern
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustCompilePattern should panic with invalid pattern")
|
||||
}
|
||||
}()
|
||||
MustCompilePattern(`[invalid`)
|
||||
}
|
||||
|
||||
func TestAdditionalValidationEdgeCases(t *testing.T) {
|
||||
// Test edge cases for ValidateURL
|
||||
t.Run("ValidateURL_EdgeCases", func(t *testing.T) {
|
||||
edgeCases := []struct {
|
||||
url string
|
||||
valid bool
|
||||
}{
|
||||
{"https://a.b", true},
|
||||
{"http://localhost", true},
|
||||
{"https://example.com/path?query=value#fragment", true},
|
||||
{"http://192.168.0.1:8080/api", false},
|
||||
{"https://", false},
|
||||
{"http://", false},
|
||||
{"https://example", true},
|
||||
}
|
||||
|
||||
for _, tc := range edgeCases {
|
||||
result := ValidateURL(tc.url)
|
||||
if result != tc.valid {
|
||||
t.Errorf("ValidateURL(%q) = %v, want %v", tc.url, result, tc.valid)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test edge cases for ValidateScopes
|
||||
t.Run("ValidateScopes_EdgeCases", func(t *testing.T) {
|
||||
edgeCases := []struct {
|
||||
scopes string
|
||||
valid bool
|
||||
}{
|
||||
{"a", true},
|
||||
{"a b", true},
|
||||
{"openid profile email", true},
|
||||
{"user_profile", true},
|
||||
{"read_all write_all", true},
|
||||
{"scope-with-dash", false},
|
||||
{"scope.with.dot", false},
|
||||
{"scope@email", false},
|
||||
{" scope", false},
|
||||
{"scope ", false},
|
||||
{"a b", true}, // pattern allows multiple spaces
|
||||
}
|
||||
|
||||
for _, tc := range edgeCases {
|
||||
result := ValidateScopes(tc.scopes)
|
||||
if result != tc.valid {
|
||||
t.Errorf("ValidateScopes(%q) = %v, want %v", tc.scopes, result, tc.valid)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test edge cases for ValidateSessionID
|
||||
t.Run("ValidateSessionID_EdgeCases", func(t *testing.T) {
|
||||
edgeCases := []struct {
|
||||
sessionID string
|
||||
valid bool
|
||||
}{
|
||||
{"12345678901234567890123456789012", true}, // 32 chars (min)
|
||||
{"1234567890123456789012345678901", false}, // 31 chars (too short)
|
||||
{string(make([]byte, 128)), false}, // 128 non-hex chars
|
||||
{"abcdef1234567890ABCDEF1234567890" + string(make([]byte, 96)), false}, // 128+ chars with non-hex
|
||||
}
|
||||
|
||||
// Generate valid 128-char hex string (max length)
|
||||
validLongHex := ""
|
||||
for i := 0; i < 128; i++ {
|
||||
validLongHex += "a"
|
||||
}
|
||||
edgeCases = append(edgeCases, struct {
|
||||
sessionID string
|
||||
valid bool
|
||||
}{validLongHex, true})
|
||||
|
||||
for _, tc := range edgeCases {
|
||||
result := ValidateSessionID(tc.sessionID)
|
||||
if result != tc.valid {
|
||||
t.Errorf("ValidateSessionID(%q) = %v, want %v", tc.sessionID, result, tc.valid)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
|
||||
@@ -148,6 +148,7 @@ func (cb *CircuitBreaker) allowRequest() bool {
|
||||
// allowHalfOpenRequest checks if a request is allowed in half-open state
|
||||
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
|
||||
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
|
||||
// #nosec G115 -- MaxRequests is a small config value that fits in int32
|
||||
if current <= int32(cb.config.MaxRequests) {
|
||||
return true
|
||||
}
|
||||
@@ -164,6 +165,7 @@ func (cb *CircuitBreaker) recordFailure() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
|
||||
cb.transitionToOpen()
|
||||
} else if state == CircuitBreakerHalfOpen {
|
||||
@@ -180,6 +182,7 @@ func (cb *CircuitBreaker) recordSuccess() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
|
||||
cb.transitionToClosed()
|
||||
}
|
||||
|
||||
@@ -191,6 +191,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
// Add jitter
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.RandomizationFactor > 0 {
|
||||
jitter := delay * re.config.RandomizationFactor
|
||||
minDelay := delay - jitter
|
||||
|
||||
@@ -0,0 +1,796 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// RETRY CONFIG TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestDefaultRetryConfig(t *testing.T) {
|
||||
config := DefaultRetryConfig()
|
||||
|
||||
if config.MaxAttempts != 3 {
|
||||
t.Errorf("Expected MaxAttempts to be 3, got %d", config.MaxAttempts)
|
||||
}
|
||||
|
||||
if config.InitialDelay != 100*time.Millisecond {
|
||||
t.Errorf("Expected InitialDelay to be 100ms, got %v", config.InitialDelay)
|
||||
}
|
||||
|
||||
if config.MaxDelay != 30*time.Second {
|
||||
t.Errorf("Expected MaxDelay to be 30s, got %v", config.MaxDelay)
|
||||
}
|
||||
|
||||
if config.Multiplier != 2.0 {
|
||||
t.Errorf("Expected Multiplier to be 2.0, got %f", config.Multiplier)
|
||||
}
|
||||
|
||||
if config.RandomizationFactor != 0.1 {
|
||||
t.Errorf("Expected RandomizationFactor to be 0.1, got %f", config.RandomizationFactor)
|
||||
}
|
||||
|
||||
if len(config.RetryableErrors) != 3 {
|
||||
t.Errorf("Expected 3 retryable errors, got %d", len(config.RetryableErrors))
|
||||
}
|
||||
|
||||
if len(config.RetryableStatusCodes) != 6 {
|
||||
t.Errorf("Expected 6 retryable status codes, got %d", len(config.RetryableStatusCodes))
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RETRY EXECUTOR TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestNewRetryExecutor(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
if executor == nil {
|
||||
t.Fatal("Expected NewRetryExecutor to return non-nil")
|
||||
}
|
||||
|
||||
if executor.config.MaxAttempts != 3 {
|
||||
t.Errorf("Expected MaxAttempts to be 3, got %d", executor.config.MaxAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRetryExecutor_InvalidConfig(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Test with invalid MaxAttempts
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 0, // Invalid
|
||||
Multiplier: 0, // Invalid
|
||||
}
|
||||
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
if executor.config.MaxAttempts != 1 {
|
||||
t.Errorf("Expected MaxAttempts to be corrected to 1, got %d", executor.config.MaxAttempts)
|
||||
}
|
||||
|
||||
if executor.config.Multiplier != 1.0 {
|
||||
t.Errorf("Expected Multiplier to be corrected to 1.0, got %f", executor.config.Multiplier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
callCount := 0
|
||||
err := executor.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_Retry(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
RetryableErrors: []string{"connection refused"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
callCount := 0
|
||||
err := executor.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
if callCount < 3 {
|
||||
return errors.New("connection refused")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected success after retries, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("Expected function to be called 3 times, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_MaxRetriesExhausted(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
RetryableErrors: []string{"timeout"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
callCount := 0
|
||||
err := executor.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return errors.New("timeout")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error after max retries exhausted")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "all retry attempts failed") {
|
||||
t.Errorf("Expected 'all retry attempts failed' error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("Expected function to be called 3 times, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_NonRetryableError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
RetryableErrors: []string{"timeout"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
callCount := 0
|
||||
err := executor.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return errors.New("permanent error")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-retryable error")
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once (non-retryable), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Second, // Long delay
|
||||
RetryableErrors: []string{"timeout"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
callCount := 0
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
var execErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
execErr = executor.ExecuteWithContext(ctx, func() error {
|
||||
callCount++
|
||||
return errors.New("timeout")
|
||||
})
|
||||
}()
|
||||
|
||||
// Cancel after a short delay
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if execErr == nil {
|
||||
t.Error("Expected error when context is cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
err := executor.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when context is already cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_Execute(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
called := false
|
||||
err := executor.Execute(context.Background(), func() error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Expected function to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_isRetryableError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
RetryableErrors: []string{"connection refused", "timeout"},
|
||||
RetryableStatusCodes: []int{500, 503},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{"nil error", nil, false},
|
||||
{"connection refused", errors.New("connection refused"), true},
|
||||
{"timeout", errors.New("TIMEOUT"), true}, // case insensitive
|
||||
{"EOF", errors.New("EOF"), false},
|
||||
{"random error", errors.New("something else"), false},
|
||||
{"context cancelled", context.Canceled, false},
|
||||
{"context deadline exceeded", context.DeadlineExceeded, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := executor.isRetryableError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_isRetryableError_HTTPError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
RetryableStatusCodes: []int{500, 503},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expected bool
|
||||
}{
|
||||
{"500 error", 500, true},
|
||||
{"503 error", 503, true},
|
||||
{"502 error (5xx)", 502, true},
|
||||
{"400 error", 400, false},
|
||||
{"401 error", 401, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
httpErr := &HTTPError{StatusCode: tt.statusCode}
|
||||
result := executor.isRetryableError(httpErr)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_isRetryableError_OIDCError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// Test retryable OIDC error
|
||||
retryableErr := &OIDCError{Code: "temporarily_unavailable", Description: "Server busy"}
|
||||
if !executor.isRetryableError(retryableErr) {
|
||||
t.Error("Expected temporarily_unavailable to be retryable")
|
||||
}
|
||||
|
||||
// Test non-retryable OIDC error
|
||||
nonRetryableErr := &OIDCError{Code: "invalid_token", Description: "Token expired"}
|
||||
if executor.isRetryableError(nonRetryableErr) {
|
||||
t.Error("Expected invalid_token to not be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_calculateDelay(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
RandomizationFactor: 0.0, // No jitter for predictable tests
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// Test exponential backoff without jitter
|
||||
delay1 := executor.calculateDelay(1)
|
||||
if delay1 != 100*time.Millisecond {
|
||||
t.Errorf("Expected 100ms for attempt 1, got %v", delay1)
|
||||
}
|
||||
|
||||
delay2 := executor.calculateDelay(2)
|
||||
if delay2 != 200*time.Millisecond {
|
||||
t.Errorf("Expected 200ms for attempt 2, got %v", delay2)
|
||||
}
|
||||
|
||||
delay3 := executor.calculateDelay(3)
|
||||
if delay3 != 400*time.Millisecond {
|
||||
t.Errorf("Expected 400ms for attempt 3, got %v", delay3)
|
||||
}
|
||||
|
||||
// Test max delay cap
|
||||
delay10 := executor.calculateDelay(10)
|
||||
if delay10 > 1*time.Second {
|
||||
t.Errorf("Expected delay capped at 1s, got %v", delay10)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_calculateDelay_WithJitter(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
RandomizationFactor: 0.5, // 50% jitter
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// With jitter, delay should be within range
|
||||
baseDelay := 100 * time.Millisecond
|
||||
minExpected := time.Duration(float64(baseDelay) * 0.5)
|
||||
maxExpected := time.Duration(float64(baseDelay) * 1.5)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
delay := executor.calculateDelay(1)
|
||||
if delay < minExpected || delay > maxExpected {
|
||||
t.Errorf("Delay %v outside expected range [%v, %v]", delay, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_Reset(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// Generate some metrics
|
||||
executor.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
executor.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
|
||||
// Reset
|
||||
executor.Reset()
|
||||
|
||||
if atomic.LoadInt64(&executor.totalRetries) != 0 {
|
||||
t.Error("Expected totalRetries to be 0 after reset")
|
||||
}
|
||||
|
||||
if atomic.LoadInt64(&executor.maxRetriesHit) != 0 {
|
||||
t.Error("Expected maxRetriesHit to be 0 after reset")
|
||||
}
|
||||
|
||||
if atomic.LoadInt64(&executor.totalRequests) != 0 {
|
||||
t.Error("Expected totalRequests to be 0 after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_IsAvailable(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
if !executor.IsAvailable() {
|
||||
t.Error("Expected IsAvailable to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_GetMetrics(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := DefaultRetryConfig()
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// Generate some metrics
|
||||
executor.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
|
||||
metrics := executor.GetMetrics()
|
||||
|
||||
// Check required fields
|
||||
if _, ok := metrics["totalRetries"]; !ok {
|
||||
t.Error("Expected 'totalRetries' in metrics")
|
||||
}
|
||||
|
||||
if _, ok := metrics["maxRetriesHit"]; !ok {
|
||||
t.Error("Expected 'maxRetriesHit' in metrics")
|
||||
}
|
||||
|
||||
if _, ok := metrics["config"]; !ok {
|
||||
t.Error("Expected 'config' in metrics")
|
||||
}
|
||||
|
||||
if _, ok := metrics["lastRetryTime"]; !ok {
|
||||
t.Error("Expected 'lastRetryTime' in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_GetMetrics_WithRetries(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
RetryableErrors: []string{"retry"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
// Generate retries
|
||||
callCount := 0
|
||||
executor.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
if callCount < 2 {
|
||||
return errors.New("retry me")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
metrics := executor.GetMetrics()
|
||||
|
||||
totalRetries := metrics["totalRetries"].(int64)
|
||||
if totalRetries < 1 {
|
||||
t.Errorf("Expected at least 1 retry, got %d", totalRetries)
|
||||
}
|
||||
|
||||
// Check for average retries calculation
|
||||
if _, ok := metrics["averageRetriesPerRequest"]; !ok {
|
||||
t.Error("Expected 'averageRetriesPerRequest' in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// RECOVERY METRICS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestNewRecoveryMetrics(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
|
||||
if rm == nil {
|
||||
t.Fatal("Expected NewRecoveryMetrics to return non-nil")
|
||||
}
|
||||
|
||||
if rm.mechanisms == nil {
|
||||
t.Error("Expected mechanisms map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_RegisterMechanism(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
|
||||
if _, exists := rm.mechanisms["circuit_breaker"]; !exists {
|
||||
t.Error("Expected mechanism to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_UnregisterMechanism(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
rm.UnregisterMechanism("circuit_breaker")
|
||||
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
|
||||
if _, exists := rm.mechanisms["circuit_breaker"]; exists {
|
||||
t.Error("Expected mechanism to be unregistered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_GetAllMetrics(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
re := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
rm.RegisterMechanism("retry_executor", re)
|
||||
|
||||
metrics := rm.GetAllMetrics()
|
||||
|
||||
if _, ok := metrics["circuit_breaker"]; !ok {
|
||||
t.Error("Expected 'circuit_breaker' in metrics")
|
||||
}
|
||||
|
||||
if _, ok := metrics["retry_executor"]; !ok {
|
||||
t.Error("Expected 'retry_executor' in metrics")
|
||||
}
|
||||
|
||||
if _, ok := metrics["summary"]; !ok {
|
||||
t.Error("Expected 'summary' in metrics")
|
||||
}
|
||||
|
||||
summary := metrics["summary"].(map[string]interface{})
|
||||
if summary["totalMechanisms"] != 2 {
|
||||
t.Errorf("Expected 2 mechanisms, got %v", summary["totalMechanisms"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_GetAllMetrics_WithActivity(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
// Generate some activity
|
||||
cb.Execute(func() error { return nil })
|
||||
cb.Execute(func() error { return nil })
|
||||
|
||||
metrics := rm.GetAllMetrics()
|
||||
summary := metrics["summary"].(map[string]interface{})
|
||||
|
||||
// Should have success rate calculated
|
||||
if _, ok := summary["overallSuccessRate"]; !ok {
|
||||
t.Error("Expected 'overallSuccessRate' in summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_GetMechanismMetrics(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
// Test existing mechanism
|
||||
metrics, ok := rm.GetMechanismMetrics("circuit_breaker")
|
||||
if !ok {
|
||||
t.Error("Expected to find circuit_breaker mechanism")
|
||||
}
|
||||
if metrics == nil {
|
||||
t.Error("Expected metrics to be non-nil")
|
||||
}
|
||||
|
||||
// Test non-existing mechanism
|
||||
_, ok = rm.GetMechanismMetrics("non_existent")
|
||||
if ok {
|
||||
t.Error("Expected to not find non_existent mechanism")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_HealthCheck(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Test with healthy mechanism
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
health := rm.HealthCheck()
|
||||
|
||||
if health["status"] != "healthy" {
|
||||
t.Errorf("Expected status 'healthy', got %v", health["status"])
|
||||
}
|
||||
|
||||
mechanisms := health["mechanisms"].(map[string]interface{})
|
||||
if mechanisms["circuit_breaker"] != "healthy" {
|
||||
t.Errorf("Expected circuit_breaker to be 'healthy', got %v", mechanisms["circuit_breaker"])
|
||||
}
|
||||
|
||||
if health["healthy"] != 1 {
|
||||
t.Errorf("Expected 1 healthy, got %v", health["healthy"])
|
||||
}
|
||||
|
||||
if health["unhealthy"] != 0 {
|
||||
t.Errorf("Expected 0 unhealthy, got %v", health["unhealthy"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_HealthCheck_Degraded(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Add a healthy mechanism
|
||||
cb1 := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("healthy_cb", cb1)
|
||||
|
||||
// Add an unhealthy mechanism (trip the circuit breaker)
|
||||
config := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
SuccessThreshold: 10,
|
||||
Timeout: 1 * time.Hour,
|
||||
MaxRequests: 1,
|
||||
}
|
||||
cb2 := NewCircuitBreaker(config, logger)
|
||||
cb2.Execute(func() error { return errors.New("fail") })
|
||||
rm.RegisterMechanism("unhealthy_cb", cb2)
|
||||
|
||||
health := rm.HealthCheck()
|
||||
|
||||
if health["status"] != "degraded" {
|
||||
t.Errorf("Expected status 'degraded', got %v", health["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_HealthCheck_Unhealthy(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Add only an unhealthy mechanism
|
||||
config := CircuitBreakerConfig{
|
||||
FailureThreshold: 1,
|
||||
SuccessThreshold: 10,
|
||||
Timeout: 1 * time.Hour,
|
||||
MaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
rm.RegisterMechanism("unhealthy_cb", cb)
|
||||
|
||||
health := rm.HealthCheck()
|
||||
|
||||
if health["status"] != "unhealthy" {
|
||||
t.Errorf("Expected status 'unhealthy', got %v", health["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryMetrics_HTTPMetricsHandler(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism("circuit_breaker", cb)
|
||||
|
||||
handler := rm.HTTPMetricsHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type 'application/json', got %s", contentType)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("Expected non-empty response body")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONCURRENT ACCESS TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestRecoveryMetrics_ConcurrentAccess(t *testing.T) {
|
||||
rm := NewRecoveryMetrics()
|
||||
logger := &mockLogger{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent registrations
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), logger)
|
||||
rm.RegisterMechanism(string(rune('a'+idx)), cb)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = rm.GetAllMetrics()
|
||||
_ = rm.HealthCheck()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify no race conditions
|
||||
health := rm.HealthCheck()
|
||||
if health == nil {
|
||||
t.Error("Expected HealthCheck to return non-nil after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutor_ConcurrentExecution(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
Multiplier: 2.0,
|
||||
RetryableErrors: []string{"retry"},
|
||||
}
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := executor.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if successCount != 100 {
|
||||
t.Errorf("Expected 100 successes, got %d", successCount)
|
||||
}
|
||||
}
|
||||
@@ -1,403 +0,0 @@
|
||||
// Package security provides security-related middleware and utilities
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityHeadersConfig configures security headers
|
||||
type SecurityHeadersConfig struct {
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity string
|
||||
StrictTransportSecurityMaxAge int // seconds
|
||||
StrictTransportSecuritySubdomains bool
|
||||
StrictTransportSecurityPreload bool
|
||||
|
||||
// Frame options
|
||||
FrameOptions string // DENY, SAMEORIGIN, or ALLOW-FROM uri
|
||||
|
||||
// Content type options
|
||||
ContentTypeOptions string // nosniff
|
||||
|
||||
// XSS protection
|
||||
XSSProtection string // 1; mode=block
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string
|
||||
CrossOriginOpenerPolicy string
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool
|
||||
CORSAllowedOrigins []string
|
||||
CORSAllowedMethods []string
|
||||
CORSAllowedHeaders []string
|
||||
CORSAllowCredentials bool
|
||||
CORSMaxAge int // seconds
|
||||
|
||||
// Custom headers
|
||||
CustomHeaders map[string]string
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool
|
||||
DisablePoweredByHeader bool
|
||||
|
||||
// Development mode (less strict for local development)
|
||||
DevelopmentMode bool
|
||||
}
|
||||
|
||||
// DefaultSecurityConfig returns a secure default configuration
|
||||
func DefaultSecurityConfig() *SecurityHeadersConfig {
|
||||
return &SecurityHeadersConfig{
|
||||
ContentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';",
|
||||
|
||||
StrictTransportSecurityMaxAge: 31536000, // 1 year
|
||||
StrictTransportSecuritySubdomains: true,
|
||||
StrictTransportSecurityPreload: true,
|
||||
|
||||
FrameOptions: "DENY",
|
||||
ContentTypeOptions: "nosniff",
|
||||
XSSProtection: "1; mode=block",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
|
||||
PermissionsPolicy: "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()",
|
||||
|
||||
CrossOriginEmbedderPolicy: "require-corp",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
|
||||
CORSEnabled: false,
|
||||
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
CORSAllowedHeaders: []string{"Authorization", "Content-Type", "X-Requested-With"},
|
||||
CORSMaxAge: 86400, // 24 hours
|
||||
|
||||
DisableServerHeader: true,
|
||||
DisablePoweredByHeader: true,
|
||||
|
||||
DevelopmentMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DevelopmentSecurityConfig returns a configuration suitable for development
|
||||
func DevelopmentSecurityConfig() *SecurityHeadersConfig {
|
||||
config := DefaultSecurityConfig()
|
||||
|
||||
// Relax CSP for development
|
||||
config.ContentSecurityPolicy = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
|
||||
|
||||
// Allow framing for development tools
|
||||
config.FrameOptions = "SAMEORIGIN"
|
||||
|
||||
// Enable CORS for local development
|
||||
config.CORSEnabled = true
|
||||
config.CORSAllowedOrigins = []string{"http://localhost:*", "http://127.0.0.1:*"}
|
||||
config.CORSAllowCredentials = true
|
||||
|
||||
// Relax cross-origin policies
|
||||
config.CrossOriginEmbedderPolicy = ""
|
||||
config.CrossOriginOpenerPolicy = "unsafe-none"
|
||||
config.CrossOriginResourcePolicy = "cross-origin"
|
||||
|
||||
config.DevelopmentMode = true
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware applies security headers to HTTP responses
|
||||
type SecurityHeadersMiddleware struct {
|
||||
config *SecurityHeadersConfig
|
||||
}
|
||||
|
||||
// NewSecurityHeadersMiddleware creates a new security headers middleware
|
||||
func NewSecurityHeadersMiddleware(config *SecurityHeadersConfig) *SecurityHeadersMiddleware {
|
||||
if config == nil {
|
||||
config = DefaultSecurityConfig()
|
||||
}
|
||||
|
||||
return &SecurityHeadersMiddleware{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply applies security headers to the response
|
||||
func (m *SecurityHeadersMiddleware) Apply(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
|
||||
// Content Security Policy
|
||||
if m.config.ContentSecurityPolicy != "" {
|
||||
headers.Set("Content-Security-Policy", m.config.ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// HSTS (only for HTTPS)
|
||||
if req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https" {
|
||||
hstsValue := m.buildHSTSHeader()
|
||||
if hstsValue != "" {
|
||||
headers.Set("Strict-Transport-Security", hstsValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Frame options
|
||||
if m.config.FrameOptions != "" {
|
||||
headers.Set("X-Frame-Options", m.config.FrameOptions)
|
||||
}
|
||||
|
||||
// Content type options
|
||||
if m.config.ContentTypeOptions != "" {
|
||||
headers.Set("X-Content-Type-Options", m.config.ContentTypeOptions)
|
||||
}
|
||||
|
||||
// XSS protection
|
||||
if m.config.XSSProtection != "" {
|
||||
headers.Set("X-XSS-Protection", m.config.XSSProtection)
|
||||
}
|
||||
|
||||
// Referrer policy
|
||||
if m.config.ReferrerPolicy != "" {
|
||||
headers.Set("Referrer-Policy", m.config.ReferrerPolicy)
|
||||
}
|
||||
|
||||
// Permissions policy
|
||||
if m.config.PermissionsPolicy != "" {
|
||||
headers.Set("Permissions-Policy", m.config.PermissionsPolicy)
|
||||
}
|
||||
|
||||
// Cross-origin policies
|
||||
if m.config.CrossOriginEmbedderPolicy != "" {
|
||||
headers.Set("Cross-Origin-Embedder-Policy", m.config.CrossOriginEmbedderPolicy)
|
||||
}
|
||||
|
||||
if m.config.CrossOriginOpenerPolicy != "" {
|
||||
headers.Set("Cross-Origin-Opener-Policy", m.config.CrossOriginOpenerPolicy)
|
||||
}
|
||||
|
||||
if m.config.CrossOriginResourcePolicy != "" {
|
||||
headers.Set("Cross-Origin-Resource-Policy", m.config.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
if m.config.CORSEnabled {
|
||||
m.applyCORSHeaders(rw, req)
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
for name, value := range m.config.CustomHeaders {
|
||||
headers.Set(name, value)
|
||||
}
|
||||
|
||||
// Remove server identification headers
|
||||
if m.config.DisableServerHeader {
|
||||
headers.Del("Server")
|
||||
}
|
||||
|
||||
if m.config.DisablePoweredByHeader {
|
||||
headers.Del("X-Powered-By")
|
||||
}
|
||||
|
||||
// Add security timestamp for debugging
|
||||
if m.config.DevelopmentMode {
|
||||
headers.Set("X-Security-Headers-Applied", time.Now().UTC().Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
// buildHSTSHeader constructs the HSTS header value
|
||||
func (m *SecurityHeadersMiddleware) buildHSTSHeader() string {
|
||||
if m.config.StrictTransportSecurityMaxAge <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := []string{
|
||||
"max-age=" + string(rune(m.config.StrictTransportSecurityMaxAge)),
|
||||
}
|
||||
|
||||
if m.config.StrictTransportSecuritySubdomains {
|
||||
parts = append(parts, "includeSubDomains")
|
||||
}
|
||||
|
||||
if m.config.StrictTransportSecurityPreload {
|
||||
parts = append(parts, "preload")
|
||||
}
|
||||
|
||||
return strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// applyCORSHeaders applies CORS headers based on the request
|
||||
func (m *SecurityHeadersMiddleware) applyCORSHeaders(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
origin := req.Header.Get("Origin")
|
||||
|
||||
// Check if origin is allowed
|
||||
if origin != "" && m.isOriginAllowed(origin) {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
} else if len(m.config.CORSAllowedOrigins) == 1 && m.config.CORSAllowedOrigins[0] == "*" {
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Set other CORS headers
|
||||
if len(m.config.CORSAllowedMethods) > 0 {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(m.config.CORSAllowedMethods, ", "))
|
||||
}
|
||||
|
||||
if len(m.config.CORSAllowedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(m.config.CORSAllowedHeaders, ", "))
|
||||
}
|
||||
|
||||
if m.config.CORSAllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if m.config.CORSMaxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", string(rune(m.config.CORSMaxAge)))
|
||||
}
|
||||
|
||||
// Handle preflight requests
|
||||
if req.Method == "OPTIONS" {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(m.config.CORSAllowedMethods, ", "))
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(m.config.CORSAllowedHeaders, ", "))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if the origin is in the allowed list
|
||||
func (m *SecurityHeadersMiddleware) isOriginAllowed(origin string) bool {
|
||||
for _, allowed := range m.config.CORSAllowedOrigins {
|
||||
if m.matchOrigin(origin, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchOrigin checks if an origin matches an allowed pattern
|
||||
func (m *SecurityHeadersMiddleware) matchOrigin(origin, pattern string) bool {
|
||||
// Exact match
|
||||
if origin == pattern {
|
||||
return true
|
||||
}
|
||||
|
||||
// Wildcard subdomain match (e.g., "https://*.example.com")
|
||||
if strings.Contains(pattern, "*") {
|
||||
// Simple wildcard matching for subdomains
|
||||
if strings.HasPrefix(pattern, "https://*.") {
|
||||
domain := strings.TrimPrefix(pattern, "https://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(pattern, "http://*.") {
|
||||
domain := strings.TrimPrefix(pattern, "http://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Port wildcard match (e.g., "http://localhost:*")
|
||||
if strings.HasSuffix(pattern, ":*") {
|
||||
prefix := strings.TrimSuffix(pattern, ":*")
|
||||
if strings.HasPrefix(origin, prefix+":") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Wrap wraps an HTTP handler with security headers
|
||||
func (m *SecurityHeadersMiddleware) Wrap(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
m.Apply(rw, req)
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// SecurityHeadersHandler is a convenience function that creates and applies security headers
|
||||
func SecurityHeadersHandler(config *SecurityHeadersConfig) func(http.ResponseWriter, *http.Request) {
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
return middleware.Apply
|
||||
}
|
||||
|
||||
// Common security header presets
|
||||
|
||||
// StrictSecurityConfig returns a very strict security configuration
|
||||
func StrictSecurityConfig() *SecurityHeadersConfig {
|
||||
config := DefaultSecurityConfig()
|
||||
|
||||
// Very strict CSP
|
||||
config.ContentSecurityPolicy = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
|
||||
|
||||
// Stricter frame options
|
||||
config.FrameOptions = "DENY"
|
||||
|
||||
// Disable CORS entirely
|
||||
config.CORSEnabled = false
|
||||
|
||||
// Very strict cross-origin policies
|
||||
config.CrossOriginEmbedderPolicy = "require-corp"
|
||||
config.CrossOriginOpenerPolicy = "same-origin"
|
||||
config.CrossOriginResourcePolicy = "same-site"
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// APISecurityConfig returns a configuration suitable for APIs
|
||||
func APISecurityConfig() *SecurityHeadersConfig {
|
||||
config := DefaultSecurityConfig()
|
||||
|
||||
// API-friendly CSP
|
||||
config.ContentSecurityPolicy = "default-src 'none'; frame-ancestors 'none';"
|
||||
|
||||
// Enable CORS for APIs
|
||||
config.CORSEnabled = true
|
||||
config.CORSAllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"}
|
||||
config.CORSAllowedHeaders = []string{"Authorization", "Content-Type", "X-Requested-With", "X-API-Key"}
|
||||
|
||||
// API-appropriate policies
|
||||
config.CrossOriginResourcePolicy = "cross-origin"
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// ValidateConfig validates the security configuration
|
||||
func (c *SecurityHeadersConfig) Validate() error {
|
||||
// Validate HSTS max age
|
||||
if c.StrictTransportSecurityMaxAge < 0 {
|
||||
c.StrictTransportSecurityMaxAge = 0
|
||||
}
|
||||
|
||||
// Validate CORS max age
|
||||
if c.CORSMaxAge < 0 {
|
||||
c.CORSMaxAge = 0
|
||||
}
|
||||
|
||||
// Validate frame options
|
||||
validFrameOptions := []string{"DENY", "SAMEORIGIN", ""}
|
||||
isValidFrameOption := false
|
||||
for _, valid := range validFrameOptions {
|
||||
if c.FrameOptions == valid || strings.HasPrefix(c.FrameOptions, "ALLOW-FROM ") {
|
||||
isValidFrameOption = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isValidFrameOption {
|
||||
c.FrameOptions = "DENY"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyToResponseWriter is a helper function to quickly apply security headers
|
||||
func ApplySecurityHeaders(rw http.ResponseWriter, req *http.Request, config *SecurityHeadersConfig) {
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
middleware.Apply(rw, req)
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultSecurityConfig(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
|
||||
if config.ContentSecurityPolicy == "" {
|
||||
t.Error("Expected default CSP to be set")
|
||||
}
|
||||
|
||||
if config.FrameOptions != "DENY" {
|
||||
t.Errorf("Expected frame options to be DENY, got %s", config.FrameOptions)
|
||||
}
|
||||
|
||||
if !config.DisableServerHeader {
|
||||
t.Error("Expected server header to be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersMiddleware_Apply(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
// Create a mock request (HTTPS)
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Mock TLS
|
||||
|
||||
// Create a response recorder
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Apply security headers
|
||||
middleware.Apply(rr, req)
|
||||
|
||||
headers := rr.Header()
|
||||
|
||||
// Check that security headers are set
|
||||
if headers.Get("Content-Security-Policy") == "" {
|
||||
t.Error("Expected CSP header to be set")
|
||||
}
|
||||
|
||||
if headers.Get("X-Frame-Options") != "DENY" {
|
||||
t.Errorf("Expected X-Frame-Options to be DENY, got %s", headers.Get("X-Frame-Options"))
|
||||
}
|
||||
|
||||
if headers.Get("X-Content-Type-Options") != "nosniff" {
|
||||
t.Errorf("Expected X-Content-Type-Options to be nosniff, got %s", headers.Get("X-Content-Type-Options"))
|
||||
}
|
||||
|
||||
if headers.Get("X-XSS-Protection") != "1; mode=block" {
|
||||
t.Errorf("Expected X-XSS-Protection to be '1; mode=block', got %s", headers.Get("X-XSS-Protection"))
|
||||
}
|
||||
|
||||
// Check HSTS for HTTPS requests
|
||||
hsts := headers.Get("Strict-Transport-Security")
|
||||
if hsts == "" {
|
||||
t.Error("Expected HSTS header for HTTPS request")
|
||||
}
|
||||
|
||||
if !strings.Contains(hsts, "max-age=") {
|
||||
t.Error("Expected HSTS header to contain max-age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersMiddleware_HTTPSOnly(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
// Test HTTP request (no HSTS)
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.Apply(rr, req)
|
||||
|
||||
if rr.Header().Get("Strict-Transport-Security") != "" {
|
||||
t.Error("Expected no HSTS header for HTTP request")
|
||||
}
|
||||
|
||||
// Test HTTPS request (with HSTS)
|
||||
req = httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
middleware.Apply(rr, req)
|
||||
|
||||
if rr.Header().Get("Strict-Transport-Security") == "" {
|
||||
t.Error("Expected HSTS header for HTTPS request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSHeaders(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
config.CORSEnabled = true
|
||||
config.CORSAllowedOrigins = []string{"https://example.com", "https://*.test.com"}
|
||||
config.CORSAllowCredentials = true
|
||||
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectedOrigin string
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
origin: "https://example.com",
|
||||
expectedOrigin: "https://example.com",
|
||||
},
|
||||
{
|
||||
name: "wildcard subdomain match",
|
||||
origin: "https://api.test.com",
|
||||
expectedOrigin: "https://api.test.com",
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
origin: "https://malicious.com",
|
||||
expectedOrigin: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware.Apply(rr, req)
|
||||
|
||||
actualOrigin := rr.Header().Get("Access-Control-Allow-Origin")
|
||||
if actualOrigin != tt.expectedOrigin {
|
||||
t.Errorf("Expected origin %s, got %s", tt.expectedOrigin, actualOrigin)
|
||||
}
|
||||
|
||||
if tt.expectedOrigin != "" {
|
||||
// Should have credentials header
|
||||
if rr.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||
t.Error("Expected credentials header for allowed origin")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSPreflight(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
config.CORSEnabled = true
|
||||
config.CORSAllowedOrigins = []string{"*"}
|
||||
config.CORSAllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
req := httptest.NewRequest("OPTIONS", "https://example.com/test", nil)
|
||||
req.Header.Set("Origin", "https://other.com")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
middleware.Apply(rr, req)
|
||||
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Error("Expected wildcard origin for preflight request")
|
||||
}
|
||||
|
||||
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Error("Expected methods header for preflight request")
|
||||
}
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for preflight, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginMatching(t *testing.T) {
|
||||
config := &SecurityHeadersConfig{
|
||||
CORSEnabled: true,
|
||||
CORSAllowedOrigins: []string{
|
||||
"https://example.com",
|
||||
"https://*.example.com",
|
||||
"http://localhost:*",
|
||||
},
|
||||
}
|
||||
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
tests := []struct {
|
||||
origin string
|
||||
expected bool
|
||||
}{
|
||||
{"https://example.com", true},
|
||||
{"https://api.example.com", true},
|
||||
{"https://sub.api.example.com", true},
|
||||
{"http://localhost:3000", true},
|
||||
{"http://localhost:8080", true},
|
||||
{"https://malicious.com", false},
|
||||
{"http://example.com", false}, // Different scheme
|
||||
{"https://example.com.evil.com", false}, // Domain suffix attack
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.origin, func(t *testing.T) {
|
||||
result := middleware.isOriginAllowed(tt.origin)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Origin %s: expected %v, got %v", tt.origin, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevelopmentMode(t *testing.T) {
|
||||
config := DevelopmentSecurityConfig()
|
||||
|
||||
if !config.DevelopmentMode {
|
||||
t.Error("Expected development mode to be enabled")
|
||||
}
|
||||
|
||||
if !config.CORSEnabled {
|
||||
t.Error("Expected CORS to be enabled in development mode")
|
||||
}
|
||||
|
||||
if config.FrameOptions != "SAMEORIGIN" {
|
||||
t.Errorf("Expected frame options to be SAMEORIGIN in dev mode, got %s", config.FrameOptions)
|
||||
}
|
||||
|
||||
// Should be less strict CSP
|
||||
if strings.Contains(config.ContentSecurityPolicy, "'none'") {
|
||||
t.Error("Expected less strict CSP in development mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictSecurityConfig(t *testing.T) {
|
||||
config := StrictSecurityConfig()
|
||||
|
||||
if !strings.Contains(config.ContentSecurityPolicy, "'none'") {
|
||||
t.Error("Expected very strict CSP with 'none' defaults")
|
||||
}
|
||||
|
||||
if config.CORSEnabled {
|
||||
t.Error("Expected CORS to be disabled in strict mode")
|
||||
}
|
||||
|
||||
if config.FrameOptions != "DENY" {
|
||||
t.Error("Expected frame options to be DENY in strict mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPISecurityConfig(t *testing.T) {
|
||||
config := APISecurityConfig()
|
||||
|
||||
if !config.CORSEnabled {
|
||||
t.Error("Expected CORS to be enabled for API config")
|
||||
}
|
||||
|
||||
methods := config.CORSAllowedMethods
|
||||
expectedMethods := []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"}
|
||||
|
||||
for _, method := range expectedMethods {
|
||||
found := false
|
||||
for _, allowed := range methods {
|
||||
if allowed == method {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected method %s to be allowed in API config", method)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareWrap(t *testing.T) {
|
||||
config := DefaultSecurityConfig()
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
// Create a simple handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
// Wrap with security middleware
|
||||
wrappedHandler := middleware.Wrap(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Check response
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if rr.Body.String() != "OK" {
|
||||
t.Errorf("Expected body 'OK', got %s", rr.Body.String())
|
||||
}
|
||||
|
||||
// Check security headers were applied
|
||||
if rr.Header().Get("X-Frame-Options") == "" {
|
||||
t.Error("Expected security headers to be applied by wrapper")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
config := &SecurityHeadersConfig{
|
||||
StrictTransportSecurityMaxAge: -1,
|
||||
CORSMaxAge: -1,
|
||||
FrameOptions: "INVALID",
|
||||
}
|
||||
|
||||
err := config.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected validation error: %v", err)
|
||||
}
|
||||
|
||||
// Should fix invalid values
|
||||
if config.StrictTransportSecurityMaxAge != 0 {
|
||||
t.Error("Expected negative HSTS max age to be reset to 0")
|
||||
}
|
||||
|
||||
if config.CORSMaxAge != 0 {
|
||||
t.Error("Expected negative CORS max age to be reset to 0")
|
||||
}
|
||||
|
||||
if config.FrameOptions != "DENY" {
|
||||
t.Error("Expected invalid frame options to be reset to DENY")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecurityHeadersApply(b *testing.B) {
|
||||
config := DefaultSecurityConfig()
|
||||
middleware := NewSecurityHeadersMiddleware(config)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
rr := httptest.NewRecorder()
|
||||
middleware.Apply(rr, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,394 +0,0 @@
|
||||
// Package singleton provides a centralized, thread-safe singleton management system
|
||||
// that consolidates all singleton patterns used throughout the application.
|
||||
// It ensures proper initialization, lifecycle management, and graceful shutdown.
|
||||
package singleton
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Registry is the centralized singleton registry that manages all singleton instances
|
||||
// in the application. It provides thread-safe initialization, access, and cleanup.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
instances map[string]*Instance
|
||||
groups map[string]*Group
|
||||
shutdown int32
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Instance represents a singleton instance with lifecycle management
|
||||
type Instance struct {
|
||||
name string
|
||||
value interface{}
|
||||
initializer func() interface{}
|
||||
finalizer func(interface{})
|
||||
once sync.Once
|
||||
refCount int32
|
||||
}
|
||||
|
||||
// Group represents a group of related singletons
|
||||
type Group struct {
|
||||
name string
|
||||
instances map[string]*Instance
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
// globalRegistry is the singleton registry instance
|
||||
globalRegistry *Registry
|
||||
// registryOnce ensures single initialization
|
||||
registryOnce sync.Once
|
||||
)
|
||||
|
||||
// Get returns the global singleton registry
|
||||
func Get() *Registry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// Register registers a new singleton with its initializer and optional finalizer
|
||||
func (r *Registry) Register(name string, initializer func() interface{}, finalizer func(interface{})) error {
|
||||
if atomic.LoadInt32(&r.shutdown) == 1 {
|
||||
return fmt.Errorf("registry is shutting down")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.instances[name]; exists {
|
||||
return fmt.Errorf("singleton %s already registered", name)
|
||||
}
|
||||
|
||||
r.instances[name] = &Instance{
|
||||
name: name,
|
||||
initializer: initializer,
|
||||
finalizer: finalizer,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInstance retrieves or initializes a singleton instance
|
||||
func (r *Registry) GetInstance(name string) (interface{}, error) {
|
||||
if atomic.LoadInt32(&r.shutdown) == 1 {
|
||||
return nil, fmt.Errorf("registry is shutting down")
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
instance, exists := r.instances[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("singleton %s not registered", name)
|
||||
}
|
||||
|
||||
// Initialize the singleton if needed
|
||||
instance.once.Do(func() {
|
||||
if instance.initializer != nil {
|
||||
instance.value = instance.initializer()
|
||||
atomic.AddInt32(&instance.refCount, 1)
|
||||
}
|
||||
})
|
||||
|
||||
return instance.value, nil
|
||||
}
|
||||
|
||||
// MustGet retrieves a singleton instance, panicking if not found
|
||||
func (r *Registry) MustGet(name string) interface{} {
|
||||
val, err := r.GetInstance(name)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("singleton %s: %v", name, err))
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// RegisterGroup creates a new singleton group
|
||||
func (r *Registry) RegisterGroup(name string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.groups[name]; exists {
|
||||
return fmt.Errorf("group %s already exists", name)
|
||||
}
|
||||
|
||||
r.groups[name] = &Group{
|
||||
name: name,
|
||||
instances: make(map[string]*Instance),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddToGroup adds a singleton to a group
|
||||
func (r *Registry) AddToGroup(groupName, singletonName string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
group, groupExists := r.groups[groupName]
|
||||
if !groupExists {
|
||||
return fmt.Errorf("group %s does not exist", groupName)
|
||||
}
|
||||
|
||||
instance, instanceExists := r.instances[singletonName]
|
||||
if !instanceExists {
|
||||
return fmt.Errorf("singleton %s not registered", singletonName)
|
||||
}
|
||||
|
||||
group.mu.Lock()
|
||||
defer group.mu.Unlock()
|
||||
|
||||
group.instances[singletonName] = instance
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGroup retrieves all singletons in a group
|
||||
func (r *Registry) GetGroup(name string) (map[string]interface{}, error) {
|
||||
r.mu.RLock()
|
||||
group, exists := r.groups[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("group %s does not exist", name)
|
||||
}
|
||||
|
||||
group.mu.RLock()
|
||||
defer group.mu.RUnlock()
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for name, instance := range group.instances {
|
||||
if instance.value != nil {
|
||||
result[name] = instance.value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AddReference increments the reference count for a singleton
|
||||
func (r *Registry) AddReference(name string) error {
|
||||
r.mu.RLock()
|
||||
instance, exists := r.instances[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("singleton %s not registered", name)
|
||||
}
|
||||
|
||||
atomic.AddInt32(&instance.refCount, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReleaseReference decrements the reference count for a singleton
|
||||
func (r *Registry) ReleaseReference(name string) error {
|
||||
r.mu.RLock()
|
||||
instance, exists := r.instances[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("singleton %s not registered", name)
|
||||
}
|
||||
|
||||
count := atomic.AddInt32(&instance.refCount, -1)
|
||||
if count == 0 && instance.finalizer != nil && instance.value != nil {
|
||||
// Run finalizer when last reference is released
|
||||
go instance.finalizer(instance.value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReferenceCount returns the reference count for a singleton
|
||||
func (r *Registry) GetReferenceCount(name string) (int32, error) {
|
||||
r.mu.RLock()
|
||||
instance, exists := r.instances[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return 0, fmt.Errorf("singleton %s not registered", name)
|
||||
}
|
||||
|
||||
return atomic.LoadInt32(&instance.refCount), nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down all singletons
|
||||
func (r *Registry) Shutdown(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&r.shutdown, 0, 1) {
|
||||
return fmt.Errorf("registry already shutting down")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Create error channel for collecting shutdown errors
|
||||
errChan := make(chan error, len(r.instances))
|
||||
|
||||
// Run finalizers for all initialized singletons
|
||||
for name, instance := range r.instances {
|
||||
if instance.value != nil && instance.finalizer != nil {
|
||||
r.wg.Add(1)
|
||||
go func(n string, i *Instance) {
|
||||
defer r.wg.Done()
|
||||
|
||||
// Run finalizer with panic recovery
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("finalizer for %s panicked: %v", n, r)
|
||||
}
|
||||
}()
|
||||
i.finalizer(i.value)
|
||||
}()
|
||||
}(name, instance)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all finalizers to complete or timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// All finalizers completed
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("shutdown timeout: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// Collect any errors
|
||||
close(errChan)
|
||||
var errs []error
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear all instances
|
||||
r.instances = make(map[string]*Instance)
|
||||
r.groups = make(map[string]*Group)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("shutdown errors: %v", errs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset resets the registry (mainly for testing)
|
||||
func (r *Registry) Reset() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.instances = make(map[string]*Instance)
|
||||
r.groups = make(map[string]*Group)
|
||||
atomic.StoreInt32(&r.shutdown, 0)
|
||||
}
|
||||
|
||||
// Stats returns statistics about the registry
|
||||
type Stats struct {
|
||||
TotalRegistered int
|
||||
TotalInitialized int
|
||||
TotalGroups int
|
||||
TotalReferences int32
|
||||
}
|
||||
|
||||
// GetStats returns current registry statistics
|
||||
func (r *Registry) GetStats() Stats {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
stats := Stats{
|
||||
TotalRegistered: len(r.instances),
|
||||
TotalGroups: len(r.groups),
|
||||
}
|
||||
|
||||
for _, instance := range r.instances {
|
||||
if instance.value != nil {
|
||||
stats.TotalInitialized++
|
||||
}
|
||||
stats.TotalReferences += atomic.LoadInt32(&instance.refCount)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Builder provides a fluent interface for registering singletons
|
||||
type Builder struct {
|
||||
registry *Registry
|
||||
name string
|
||||
initializer func() interface{}
|
||||
finalizer func(interface{})
|
||||
group string
|
||||
}
|
||||
|
||||
// NewBuilder creates a new singleton builder
|
||||
func NewBuilder(name string) *Builder {
|
||||
return &Builder{
|
||||
registry: Get(),
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitializer sets the initializer function
|
||||
func (b *Builder) WithInitializer(init func() interface{}) *Builder {
|
||||
b.initializer = init
|
||||
return b
|
||||
}
|
||||
|
||||
// WithFinalizer sets the finalizer function
|
||||
func (b *Builder) WithFinalizer(final func(interface{})) *Builder {
|
||||
b.finalizer = final
|
||||
return b
|
||||
}
|
||||
|
||||
// InGroup adds the singleton to a group
|
||||
func (b *Builder) InGroup(group string) *Builder {
|
||||
b.group = group
|
||||
return b
|
||||
}
|
||||
|
||||
// Register registers the singleton with the configured options
|
||||
func (b *Builder) Register() error {
|
||||
if err := b.registry.Register(b.name, b.initializer, b.finalizer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b.group != "" {
|
||||
// Ensure group exists
|
||||
if err := b.registry.RegisterGroup(b.group); err != nil {
|
||||
// Group might already exist, which is ok
|
||||
if !contains(err.Error(), "already exists") {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return b.registry.AddToGroup(b.group, b.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,970 +0,0 @@
|
||||
package singleton
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGet_Singleton tests that Get() returns the same instance
|
||||
func TestGet_Singleton(t *testing.T) {
|
||||
registry1 := Get()
|
||||
registry2 := Get()
|
||||
|
||||
if registry1 != registry2 {
|
||||
t.Error("Get() should return the same instance (singleton)")
|
||||
}
|
||||
|
||||
if registry1 == nil {
|
||||
t.Error("Get() should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Register tests singleton registration
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
initializer := func() interface{} {
|
||||
return "test-value"
|
||||
}
|
||||
|
||||
finalizer := func(v interface{}) {
|
||||
// Mock finalizer
|
||||
}
|
||||
|
||||
// Test successful registration
|
||||
err := registry.Register("test-singleton", initializer, finalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify instance was registered
|
||||
if len(registry.instances) != 1 {
|
||||
t.Error("Instance should be registered")
|
||||
}
|
||||
|
||||
instance := registry.instances["test-singleton"]
|
||||
if instance == nil {
|
||||
t.Error("Instance should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if instance.name != "test-singleton" {
|
||||
t.Errorf("Instance name should be 'test-singleton', got '%s'", instance.name)
|
||||
}
|
||||
|
||||
if instance.initializer == nil {
|
||||
t.Error("Instance should have initializer")
|
||||
}
|
||||
|
||||
if instance.finalizer == nil {
|
||||
t.Error("Instance should have finalizer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Register_Duplicate tests duplicate registration
|
||||
func TestRegistry_Register_Duplicate(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
initializer := func() interface{} {
|
||||
return "test-value"
|
||||
}
|
||||
|
||||
// Register first time
|
||||
err := registry.Register("test-singleton", initializer, nil)
|
||||
if err != nil {
|
||||
t.Errorf("First registration should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Register again - should fail
|
||||
err = registry.Register("test-singleton", initializer, nil)
|
||||
if err == nil {
|
||||
t.Error("Duplicate registration should fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "already registered") {
|
||||
t.Errorf("Error should mention already registered, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Register_DuringShutdown tests registration during shutdown
|
||||
func TestRegistry_Register_DuringShutdown(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
shutdown: 1, // Already shutting down
|
||||
}
|
||||
|
||||
initializer := func() interface{} {
|
||||
return "test-value"
|
||||
}
|
||||
|
||||
err := registry.Register("test-singleton", initializer, nil)
|
||||
if err == nil {
|
||||
t.Error("Registration during shutdown should fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "shutting down") {
|
||||
t.Errorf("Error should mention shutting down, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_GetInstance tests singleton retrieval and initialization
|
||||
func TestRegistry_GetInstance(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
callCount := int32(0)
|
||||
testValue := "test-value"
|
||||
|
||||
initializer := func() interface{} {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
return testValue
|
||||
}
|
||||
|
||||
// Register singleton
|
||||
err := registry.Register("test-singleton", initializer, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// First get - should initialize
|
||||
value1, err := registry.GetInstance("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetInstance should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if value1 != testValue {
|
||||
t.Errorf("Value should be '%s', got '%v'", testValue, value1)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&callCount) != 1 {
|
||||
t.Errorf("Initializer should be called once, called %d times", callCount)
|
||||
}
|
||||
|
||||
// Second get - should return same instance without calling initializer
|
||||
value2, err := registry.GetInstance("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetInstance should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if value2 != testValue {
|
||||
t.Errorf("Value should be '%s', got '%v'", testValue, value2)
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&callCount) != 1 {
|
||||
t.Errorf("Initializer should still be called only once, called %d times", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_GetInstance_NotRegistered tests getting unregistered singleton
|
||||
func TestRegistry_GetInstance_NotRegistered(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
value, err := registry.GetInstance("non-existent")
|
||||
if err == nil {
|
||||
t.Error("GetInstance of non-existent singleton should fail")
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
t.Error("Value should be nil for non-existent singleton")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "not registered") {
|
||||
t.Errorf("Error should mention not registered, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_GetInstance_DuringShutdown tests getting instance during shutdown
|
||||
func TestRegistry_GetInstance_DuringShutdown(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
shutdown: 1, // Already shutting down
|
||||
}
|
||||
|
||||
value, err := registry.GetInstance("test-singleton")
|
||||
if err == nil {
|
||||
t.Error("GetInstance during shutdown should fail")
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
t.Error("Value should be nil during shutdown")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "shutting down") {
|
||||
t.Errorf("Error should mention shutting down, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_MustGet tests MustGet method
|
||||
func TestRegistry_MustGet(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
testValue := "test-value"
|
||||
initializer := func() interface{} {
|
||||
return testValue
|
||||
}
|
||||
|
||||
// Register singleton
|
||||
err := registry.Register("test-singleton", initializer, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// MustGet should succeed
|
||||
value := registry.MustGet("test-singleton")
|
||||
if value != testValue {
|
||||
t.Errorf("Value should be '%s', got '%v'", testValue, value)
|
||||
}
|
||||
|
||||
// MustGet non-existent should panic
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGet of non-existent singleton should panic")
|
||||
}
|
||||
}()
|
||||
|
||||
registry.MustGet("non-existent")
|
||||
}
|
||||
|
||||
// TestRegistry_RegisterGroup tests group registration
|
||||
func TestRegistry_RegisterGroup(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Test successful group registration
|
||||
err := registry.RegisterGroup("test-group")
|
||||
if err != nil {
|
||||
t.Errorf("RegisterGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify group was registered
|
||||
if len(registry.groups) != 1 {
|
||||
t.Error("Group should be registered")
|
||||
}
|
||||
|
||||
group := registry.groups["test-group"]
|
||||
if group == nil {
|
||||
t.Error("Group should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
if group.name != "test-group" {
|
||||
t.Errorf("Group name should be 'test-group', got '%s'", group.name)
|
||||
}
|
||||
|
||||
// Test duplicate group registration
|
||||
err = registry.RegisterGroup("test-group")
|
||||
if err == nil {
|
||||
t.Error("Duplicate group registration should fail")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "already exists") {
|
||||
t.Errorf("Error should mention already exists, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_AddToGroup tests adding singletons to groups
|
||||
func TestRegistry_AddToGroup(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register a singleton
|
||||
initializer := func() interface{} {
|
||||
return "test-value"
|
||||
}
|
||||
|
||||
err := registry.Register("test-singleton", initializer, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Register a group
|
||||
err = registry.RegisterGroup("test-group")
|
||||
if err != nil {
|
||||
t.Errorf("RegisterGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Add singleton to group
|
||||
err = registry.AddToGroup("test-group", "test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("AddToGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify singleton is in group
|
||||
group := registry.groups["test-group"]
|
||||
if len(group.instances) != 1 {
|
||||
t.Error("Group should contain one instance")
|
||||
}
|
||||
|
||||
if group.instances["test-singleton"] == nil {
|
||||
t.Error("Singleton should be in group")
|
||||
}
|
||||
|
||||
// Test adding to non-existent group
|
||||
err = registry.AddToGroup("non-existent-group", "test-singleton")
|
||||
if err == nil {
|
||||
t.Error("Adding to non-existent group should fail")
|
||||
}
|
||||
|
||||
// Test adding non-existent singleton to group
|
||||
err = registry.AddToGroup("test-group", "non-existent-singleton")
|
||||
if err == nil {
|
||||
t.Error("Adding non-existent singleton should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_GetGroup tests retrieving group instances
|
||||
func TestRegistry_GetGroup(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register singletons
|
||||
err := registry.Register("test-singleton-1", func() interface{} {
|
||||
return "value-1"
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
err = registry.Register("test-singleton-2", func() interface{} {
|
||||
return "value-2"
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Register group and add singletons
|
||||
err = registry.RegisterGroup("test-group")
|
||||
if err != nil {
|
||||
t.Errorf("RegisterGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
err = registry.AddToGroup("test-group", "test-singleton-1")
|
||||
if err != nil {
|
||||
t.Errorf("AddToGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
err = registry.AddToGroup("test-group", "test-singleton-2")
|
||||
if err != nil {
|
||||
t.Errorf("AddToGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singletons
|
||||
_, _ = registry.GetInstance("test-singleton-1")
|
||||
_, _ = registry.GetInstance("test-singleton-2")
|
||||
|
||||
// Get group
|
||||
groupInstances, err := registry.GetGroup("test-group")
|
||||
if err != nil {
|
||||
t.Errorf("GetGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(groupInstances) != 2 {
|
||||
t.Errorf("Group should contain 2 instances, got %d", len(groupInstances))
|
||||
}
|
||||
|
||||
if groupInstances["test-singleton-1"] != "value-1" {
|
||||
t.Error("Group should contain correct instance values")
|
||||
}
|
||||
|
||||
if groupInstances["test-singleton-2"] != "value-2" {
|
||||
t.Error("Group should contain correct instance values")
|
||||
}
|
||||
|
||||
// Test getting non-existent group
|
||||
_, err = registry.GetGroup("non-existent-group")
|
||||
if err == nil {
|
||||
t.Error("Getting non-existent group should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_ReferenceCountingv tests reference counting
|
||||
func TestRegistry_ReferenceCountingv(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
finalizerCalled := int32(0)
|
||||
finalizer := func(v interface{}) {
|
||||
atomic.AddInt32(&finalizerCalled, 1)
|
||||
}
|
||||
|
||||
// Register singleton
|
||||
err := registry.Register("test-singleton", func() interface{} {
|
||||
return "test-value"
|
||||
}, finalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singleton (this adds 1 reference)
|
||||
_, err = registry.GetInstance("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetInstance should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Check initial reference count
|
||||
count, err := registry.GetReferenceCount("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetReferenceCount should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Reference count should be 1, got %d", count)
|
||||
}
|
||||
|
||||
// Add reference
|
||||
err = registry.AddReference("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("AddReference should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
count, _ = registry.GetReferenceCount("test-singleton")
|
||||
if count != 2 {
|
||||
t.Errorf("Reference count should be 2, got %d", count)
|
||||
}
|
||||
|
||||
// Release reference
|
||||
err = registry.ReleaseReference("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("ReleaseReference should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
count, _ = registry.GetReferenceCount("test-singleton")
|
||||
if count != 1 {
|
||||
t.Errorf("Reference count should be 1, got %d", count)
|
||||
}
|
||||
|
||||
// Release last reference - should trigger finalizer
|
||||
err = registry.ReleaseReference("test-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("ReleaseReference should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
count, _ = registry.GetReferenceCount("test-singleton")
|
||||
if count != 0 {
|
||||
t.Errorf("Reference count should be 0, got %d", count)
|
||||
}
|
||||
|
||||
// Wait for finalizer to run (it runs in goroutine)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if atomic.LoadInt32(&finalizerCalled) != 1 {
|
||||
t.Errorf("Finalizer should be called once, called %d times", finalizerCalled)
|
||||
}
|
||||
|
||||
// Test reference operations on non-existent singleton
|
||||
err = registry.AddReference("non-existent")
|
||||
if err == nil {
|
||||
t.Error("AddReference on non-existent singleton should fail")
|
||||
}
|
||||
|
||||
err = registry.ReleaseReference("non-existent")
|
||||
if err == nil {
|
||||
t.Error("ReleaseReference on non-existent singleton should fail")
|
||||
}
|
||||
|
||||
_, err = registry.GetReferenceCount("non-existent")
|
||||
if err == nil {
|
||||
t.Error("GetReferenceCount on non-existent singleton should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Shutdown tests graceful shutdown
|
||||
func TestRegistry_Shutdown(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
finalizerCalled := int32(0)
|
||||
finalizer := func(v interface{}) {
|
||||
atomic.AddInt32(&finalizerCalled, 1)
|
||||
}
|
||||
|
||||
// Register and initialize singletons
|
||||
err := registry.Register("test-singleton-1", func() interface{} {
|
||||
return "value-1"
|
||||
}, finalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
err = registry.Register("test-singleton-2", func() interface{} {
|
||||
return "value-2"
|
||||
}, finalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singletons
|
||||
_, _ = registry.GetInstance("test-singleton-1")
|
||||
_, _ = registry.GetInstance("test-singleton-2")
|
||||
|
||||
// Shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = registry.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Shutdown should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify finalizers were called
|
||||
if atomic.LoadInt32(&finalizerCalled) != 2 {
|
||||
t.Errorf("Finalizers should be called 2 times, called %d times", finalizerCalled)
|
||||
}
|
||||
|
||||
// Verify registry is cleared
|
||||
if len(registry.instances) != 0 {
|
||||
t.Error("Instances should be cleared after shutdown")
|
||||
}
|
||||
|
||||
if len(registry.groups) != 0 {
|
||||
t.Error("Groups should be cleared after shutdown")
|
||||
}
|
||||
|
||||
// Verify shutdown flag is set
|
||||
if atomic.LoadInt32(®istry.shutdown) != 1 {
|
||||
t.Error("Shutdown flag should be set")
|
||||
}
|
||||
|
||||
// Test double shutdown
|
||||
err = registry.Shutdown(ctx)
|
||||
if err == nil {
|
||||
t.Error("Double shutdown should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Shutdown_Timeout tests shutdown timeout
|
||||
func TestRegistry_Shutdown_Timeout(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register singleton with slow finalizer
|
||||
slowFinalizer := func(v interface{}) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
err := registry.Register("slow-singleton", func() interface{} {
|
||||
return "value"
|
||||
}, slowFinalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singleton
|
||||
_, _ = registry.GetInstance("slow-singleton")
|
||||
|
||||
// Shutdown with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = registry.Shutdown(ctx)
|
||||
if err == nil {
|
||||
t.Error("Shutdown should timeout")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "timeout") {
|
||||
t.Errorf("Error should mention timeout, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Shutdown_PanicRecovery tests panic recovery during shutdown
|
||||
func TestRegistry_Shutdown_PanicRecovery(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register singleton with panicking finalizer
|
||||
panicFinalizer := func(v interface{}) {
|
||||
panic("finalizer panic")
|
||||
}
|
||||
|
||||
err := registry.Register("panic-singleton", func() interface{} {
|
||||
return "value"
|
||||
}, panicFinalizer)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singleton
|
||||
_, _ = registry.GetInstance("panic-singleton")
|
||||
|
||||
// Shutdown should handle panic
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = registry.Shutdown(ctx)
|
||||
if err == nil {
|
||||
t.Error("Shutdown should report finalizer panic")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "panicked") {
|
||||
t.Errorf("Error should mention panic, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_Reset tests registry reset
|
||||
func TestRegistry_Reset(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
shutdown: 1,
|
||||
}
|
||||
|
||||
// Add some data
|
||||
registry.instances["test"] = &Instance{}
|
||||
registry.groups["test"] = &Group{}
|
||||
|
||||
// Reset
|
||||
registry.Reset()
|
||||
|
||||
// Verify everything is cleared
|
||||
if len(registry.instances) != 0 {
|
||||
t.Error("Instances should be cleared after reset")
|
||||
}
|
||||
|
||||
if len(registry.groups) != 0 {
|
||||
t.Error("Groups should be cleared after reset")
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(®istry.shutdown) != 0 {
|
||||
t.Error("Shutdown flag should be cleared after reset")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_GetStats tests statistics
|
||||
func TestRegistry_GetStats(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register singletons
|
||||
err := registry.Register("test-singleton-1", func() interface{} {
|
||||
return "value-1"
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
err = registry.Register("test-singleton-2", func() interface{} {
|
||||
return "value-2"
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Register group
|
||||
err = registry.RegisterGroup("test-group")
|
||||
if err != nil {
|
||||
t.Errorf("RegisterGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize one singleton
|
||||
_, _ = registry.GetInstance("test-singleton-1")
|
||||
|
||||
// Add reference
|
||||
_ = registry.AddReference("test-singleton-1")
|
||||
|
||||
// Get stats
|
||||
stats := registry.GetStats()
|
||||
|
||||
if stats.TotalRegistered != 2 {
|
||||
t.Errorf("TotalRegistered should be 2, got %d", stats.TotalRegistered)
|
||||
}
|
||||
|
||||
if stats.TotalInitialized != 1 {
|
||||
t.Errorf("TotalInitialized should be 1, got %d", stats.TotalInitialized)
|
||||
}
|
||||
|
||||
if stats.TotalGroups != 1 {
|
||||
t.Errorf("TotalGroups should be 1, got %d", stats.TotalGroups)
|
||||
}
|
||||
|
||||
if stats.TotalReferences != 2 { // 1 from initialization + 1 from AddReference
|
||||
t.Errorf("TotalReferences should be 2, got %d", stats.TotalReferences)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuilder tests the fluent builder interface
|
||||
func TestBuilder(t *testing.T) {
|
||||
// Reset global registry for clean test
|
||||
Get().Reset()
|
||||
|
||||
testValue := "builder-test-value"
|
||||
|
||||
initializer := func() interface{} {
|
||||
return testValue
|
||||
}
|
||||
|
||||
finalizer := func(v interface{}) {
|
||||
// Mock finalizer for builder test
|
||||
}
|
||||
|
||||
// Test builder
|
||||
err := NewBuilder("builder-singleton").
|
||||
WithInitializer(initializer).
|
||||
WithFinalizer(finalizer).
|
||||
InGroup("builder-group").
|
||||
Register()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Builder registration should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify singleton was registered
|
||||
value, err := Get().GetInstance("builder-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetInstance should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if value != testValue {
|
||||
t.Errorf("Value should be '%s', got '%v'", testValue, value)
|
||||
}
|
||||
|
||||
// Verify group was created and singleton added
|
||||
groupInstances, err := Get().GetGroup("builder-group")
|
||||
if err != nil {
|
||||
t.Errorf("GetGroup should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(groupInstances) != 1 {
|
||||
t.Errorf("Group should contain 1 instance, got %d", len(groupInstances))
|
||||
}
|
||||
|
||||
if groupInstances["builder-singleton"] != testValue {
|
||||
t.Error("Group should contain correct instance")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuilder_WithoutGroup tests builder without group
|
||||
func TestBuilder_WithoutGroup(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
builder := &Builder{
|
||||
registry: registry,
|
||||
name: "no-group-singleton",
|
||||
}
|
||||
|
||||
err := builder.WithInitializer(func() interface{} {
|
||||
return "value"
|
||||
}).Register()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Registration without group should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify singleton was registered
|
||||
if len(registry.instances) != 1 {
|
||||
t.Error("Singleton should be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestContainsHelper tests the helper string contains function
|
||||
func TestContainsHelper(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
substr string
|
||||
expect bool
|
||||
}{
|
||||
{"hello world", "world", true},
|
||||
{"hello world", "hello", true},
|
||||
{"hello world", "lo wo", true},
|
||||
{"hello world", "xyz", false},
|
||||
{"hello", "hello world", false},
|
||||
{"", "test", false},
|
||||
{"test", "", true},
|
||||
{"", "", true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := contains(test.s, test.substr)
|
||||
if result != test.expect {
|
||||
t.Errorf("contains(%q, %q) = %v, want %v", test.s, test.substr, result, test.expect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_ConcurrentAccess tests concurrent access to registry
|
||||
func TestRegistry_ConcurrentAccess(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
callCount := int32(0)
|
||||
initializer := func() interface{} {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
return "concurrent-value"
|
||||
}
|
||||
|
||||
// Register singleton
|
||||
err := registry.Register("concurrent-singleton", initializer, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
// Concurrent access
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
value, err := registry.GetInstance("concurrent-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetInstance should succeed, got error: %v", err)
|
||||
return
|
||||
}
|
||||
if value != "concurrent-value" {
|
||||
t.Errorf("Value should be 'concurrent-value', got '%v'", value)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Initializer should be called only once despite concurrent access
|
||||
if atomic.LoadInt32(&callCount) != 1 {
|
||||
t.Errorf("Initializer should be called only once, called %d times", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistry_ConcurrentReferenceOperations tests concurrent reference operations
|
||||
func TestRegistry_ConcurrentReferenceOperations(t *testing.T) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
// Register singleton
|
||||
err := registry.Register("ref-singleton", func() interface{} {
|
||||
return "ref-value"
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Register should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Initialize singleton
|
||||
_, _ = registry.GetInstance("ref-singleton")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 20
|
||||
|
||||
// Concurrent reference operations
|
||||
wg.Add(numGoroutines * 2)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = registry.AddReference("ref-singleton")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = registry.ReleaseReference("ref-singleton")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Reference count should be consistent (initial 1 + net operations)
|
||||
count, err := registry.GetReferenceCount("ref-singleton")
|
||||
if err != nil {
|
||||
t.Errorf("GetReferenceCount should succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Count should be >= 0 due to balanced add/release operations
|
||||
if count < 0 {
|
||||
t.Errorf("Reference count should not be negative, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance verification
|
||||
func BenchmarkRegistry_GetInstance(b *testing.B) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
registry.Register("benchmark-singleton", func() interface{} {
|
||||
return "benchmark-value"
|
||||
}, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
registry.GetInstance("benchmark-singleton")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRegistry_ConcurrentGetInstance(b *testing.B) {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
registry.Register("concurrent-benchmark", func() interface{} {
|
||||
return "concurrent-value"
|
||||
}, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
registry.GetInstance("concurrent-benchmark")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBuilder_Register(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
registry := &Registry{
|
||||
instances: make(map[string]*Instance),
|
||||
groups: make(map[string]*Group),
|
||||
}
|
||||
|
||||
builder := &Builder{
|
||||
registry: registry,
|
||||
name: fmt.Sprintf("benchmark-%d", i),
|
||||
}
|
||||
|
||||
builder.WithInitializer(func() interface{} {
|
||||
return "value"
|
||||
}).Register()
|
||||
}
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
// Package testing provides unified mock implementations for tests
|
||||
package testing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnifiedMockLogger provides a standard mock logger for all tests
|
||||
type UnifiedMockLogger struct {
|
||||
LoggedMessages []string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUnifiedMockLogger() *UnifiedMockLogger {
|
||||
return &UnifiedMockLogger{
|
||||
LoggedMessages: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Debug(msg string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("DEBUG: %s", msg))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Info(msg string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("INFO: %s", msg))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Infof(format string, args ...interface{}) {
|
||||
l.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Error(msg string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.LoggedMessages = append(l.LoggedMessages, fmt.Sprintf("ERROR: %s", msg))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Error(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) GetMessages() []string {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
result := make([]string, len(l.LoggedMessages))
|
||||
copy(result, l.LoggedMessages)
|
||||
return result
|
||||
}
|
||||
|
||||
func (l *UnifiedMockLogger) Clear() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.LoggedMessages = l.LoggedMessages[:0]
|
||||
}
|
||||
|
||||
// UnifiedMockSession provides a standard mock session for all tests
|
||||
type UnifiedMockSession struct {
|
||||
authenticated bool
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
email string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
redirectCount int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUnifiedMockSession() *UnifiedMockSession {
|
||||
return &UnifiedMockSession{}
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetAuthenticated() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.authenticated
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetAuthenticated(auth bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.authenticated = auth
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetIDToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.idToken
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetIDToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.idToken = token
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetAccessToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.accessToken
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetAccessToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accessToken = token
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetRefreshToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.refreshToken
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetRefreshToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.refreshToken = token
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetEmail() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.email
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetEmail(email string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.email = email
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetCSRF() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.csrf
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetCSRF(csrf string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.csrf = csrf
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetNonce() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.nonce
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetNonce(nonce string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.nonce = nonce
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetCodeVerifier() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.codeVerifier
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetCodeVerifier(verifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.codeVerifier = verifier
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetIncomingPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.incomingPath
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) SetIncomingPath(path string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.incomingPath = path
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) GetRedirectCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.redirectCount
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) IncrementRedirectCount() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.redirectCount++
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) ResetRedirectCount() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.redirectCount = 0
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) Clear(req *http.Request, rw http.ResponseWriter) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.authenticated = false
|
||||
s.idToken = ""
|
||||
s.accessToken = ""
|
||||
s.refreshToken = ""
|
||||
s.email = ""
|
||||
s.csrf = ""
|
||||
s.nonce = ""
|
||||
s.codeVerifier = ""
|
||||
s.incomingPath = ""
|
||||
s.redirectCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) MarkDirty() {}
|
||||
|
||||
func (s *UnifiedMockSession) IsDirty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *UnifiedMockSession) ReturnToPoolSafely() {}
|
||||
|
||||
// UnifiedMockTokenVerifier provides a standard mock token verifier
|
||||
type UnifiedMockTokenVerifier struct {
|
||||
ShouldFail bool
|
||||
Error error
|
||||
}
|
||||
|
||||
func NewUnifiedMockTokenVerifier() *UnifiedMockTokenVerifier {
|
||||
return &UnifiedMockTokenVerifier{}
|
||||
}
|
||||
|
||||
func (v *UnifiedMockTokenVerifier) VerifyToken(token string) error {
|
||||
if v.ShouldFail {
|
||||
if v.Error != nil {
|
||||
return v.Error
|
||||
}
|
||||
return fmt.Errorf("mock verification failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnifiedMockTokenCache provides a standard mock token cache
|
||||
type UnifiedMockTokenCache struct {
|
||||
data map[string]map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUnifiedMockTokenCache() *UnifiedMockTokenCache {
|
||||
return &UnifiedMockTokenCache{
|
||||
data: make(map[string]map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Get(key string) (map[string]interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
value, exists := c.data[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Set(key string, claims map[string]interface{}, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.data[key] = claims
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.data, key)
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) SetMaxSize(size int) {}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.data)
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.data = make(map[string]map[string]interface{})
|
||||
}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Cleanup() {}
|
||||
|
||||
func (c *UnifiedMockTokenCache) Close() {}
|
||||
|
||||
func (c *UnifiedMockTokenCache) GetStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"size": c.Size(),
|
||||
}
|
||||
}
|
||||
|
||||
// UnifiedMockHTTPClient provides a mock HTTP client for tests
|
||||
type UnifiedMockHTTPClient struct {
|
||||
Responses map[string]*http.Response
|
||||
Errors map[string]error
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUnifiedMockHTTPClient() *UnifiedMockHTTPClient {
|
||||
return &UnifiedMockHTTPClient{
|
||||
Responses: make(map[string]*http.Response),
|
||||
Errors: make(map[string]error),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnifiedMockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
url := req.URL.String()
|
||||
if err, exists := c.Errors[url]; exists {
|
||||
return nil, err
|
||||
}
|
||||
if resp, exists := c.Responses[url]; exists {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Default response
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: http.NoBody,
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *UnifiedMockHTTPClient) SetResponse(url string, response *http.Response) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Responses[url] = response
|
||||
}
|
||||
|
||||
func (c *UnifiedMockHTTPClient) SetError(url string, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Errors[url] = err
|
||||
}
|
||||
|
||||
// TestSuite provides a unified test setup and teardown
|
||||
type TestSuite struct {
|
||||
Logger *UnifiedMockLogger
|
||||
Session *UnifiedMockSession
|
||||
TokenVerifier *UnifiedMockTokenVerifier
|
||||
TokenCache *UnifiedMockTokenCache
|
||||
HTTPClient *UnifiedMockHTTPClient
|
||||
}
|
||||
|
||||
func NewTestSuite() *TestSuite {
|
||||
return &TestSuite{
|
||||
Logger: NewUnifiedMockLogger(),
|
||||
Session: NewUnifiedMockSession(),
|
||||
TokenVerifier: NewUnifiedMockTokenVerifier(),
|
||||
TokenCache: NewUnifiedMockTokenCache(),
|
||||
HTTPClient: NewUnifiedMockHTTPClient(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *TestSuite) Setup() {
|
||||
// Common test setup
|
||||
ts.Logger.Clear()
|
||||
_ = ts.Session.Clear(nil, nil) // Safe to ignore: test helper function
|
||||
ts.TokenCache.Clear()
|
||||
ts.TokenVerifier.ShouldFail = false
|
||||
ts.TokenVerifier.Error = nil
|
||||
}
|
||||
|
||||
func (ts *TestSuite) Teardown() {
|
||||
// Common test teardown
|
||||
ts.TokenCache.Close()
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil/fixtures"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil/mocks"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil/servers"
|
||||
)
|
||||
|
||||
// Re-export types for easier access from main package tests
|
||||
type (
|
||||
// Mocks
|
||||
JWKCacheMock = mocks.JWKCache
|
||||
TokenExchangerMock = mocks.TokenExchanger
|
||||
TokenVerifierMock = mocks.TokenVerifier
|
||||
JWTVerifierMock = mocks.JWTVerifier
|
||||
SessionManagerMock = mocks.SessionManager
|
||||
CacheMock = mocks.Cache
|
||||
TokenCacheMock = mocks.TokenCache
|
||||
BlacklistMock = mocks.Blacklist
|
||||
HTTPClientMock = mocks.HTTPClient
|
||||
RoundTripperMock = mocks.RoundTripper
|
||||
LoggerMock = mocks.Logger
|
||||
|
||||
// Mock types
|
||||
JWKSet = mocks.JWKSet
|
||||
JWK = mocks.JWK
|
||||
MockTokenResponse = mocks.TokenResponse
|
||||
MockSessionData = mocks.SessionData
|
||||
IntrospectionResp = mocks.IntrospectionResponse
|
||||
|
||||
// Fixtures
|
||||
TokenFixture = fixtures.TokenFixture
|
||||
|
||||
// Servers
|
||||
OIDCServer = servers.OIDCServer
|
||||
OIDCServerConfig = servers.OIDCServerConfig
|
||||
OIDCError = servers.OIDCError
|
||||
)
|
||||
|
||||
// NewJWKCacheMock creates a new JWK cache mock
|
||||
func NewJWKCacheMock() *mocks.JWKCache {
|
||||
return new(mocks.JWKCache)
|
||||
}
|
||||
|
||||
// NewTokenExchangerMock creates a new token exchanger mock
|
||||
func NewTokenExchangerMock() *mocks.TokenExchanger {
|
||||
return new(mocks.TokenExchanger)
|
||||
}
|
||||
|
||||
// NewTokenVerifierMock creates a new token verifier mock
|
||||
func NewTokenVerifierMock() *mocks.TokenVerifier {
|
||||
return new(mocks.TokenVerifier)
|
||||
}
|
||||
|
||||
// NewJWTVerifierMock creates a new JWT verifier mock
|
||||
func NewJWTVerifierMock() *mocks.JWTVerifier {
|
||||
return new(mocks.JWTVerifier)
|
||||
}
|
||||
|
||||
// NewSessionManagerMock creates a new session manager mock
|
||||
func NewSessionManagerMock() *mocks.SessionManager {
|
||||
return new(mocks.SessionManager)
|
||||
}
|
||||
|
||||
// NewCacheMock creates a new cache mock
|
||||
func NewCacheMock() *mocks.Cache {
|
||||
return new(mocks.Cache)
|
||||
}
|
||||
|
||||
// NewTokenCacheMock creates a new token cache mock
|
||||
func NewTokenCacheMock() *mocks.TokenCache {
|
||||
return new(mocks.TokenCache)
|
||||
}
|
||||
|
||||
// NewBlacklistMock creates a new blacklist mock
|
||||
func NewBlacklistMock() *mocks.Blacklist {
|
||||
return new(mocks.Blacklist)
|
||||
}
|
||||
|
||||
// NewHTTPClientMock creates a new HTTP client mock
|
||||
func NewHTTPClientMock() *mocks.HTTPClient {
|
||||
return new(mocks.HTTPClient)
|
||||
}
|
||||
|
||||
// NewRoundTripperMock creates a new round tripper mock
|
||||
func NewRoundTripperMock() *mocks.RoundTripper {
|
||||
return new(mocks.RoundTripper)
|
||||
}
|
||||
|
||||
// NewLoggerMock creates a new logger mock
|
||||
func NewLoggerMock() *mocks.Logger {
|
||||
return new(mocks.Logger)
|
||||
}
|
||||
|
||||
// NewTokenFixture creates a new token fixture
|
||||
func NewTokenFixture() (*fixtures.TokenFixture, error) {
|
||||
return fixtures.NewTokenFixture()
|
||||
}
|
||||
|
||||
// NewOIDCServer creates a new mock OIDC server
|
||||
func NewOIDCServer(config *servers.OIDCServerConfig) *servers.OIDCServer {
|
||||
return servers.NewOIDCServer(config)
|
||||
}
|
||||
|
||||
// DefaultServerConfig returns a default server configuration
|
||||
func DefaultServerConfig() *servers.OIDCServerConfig {
|
||||
return servers.DefaultConfig()
|
||||
}
|
||||
|
||||
// GoogleServerConfig returns a Google-like server configuration
|
||||
func GoogleServerConfig() *servers.OIDCServerConfig {
|
||||
return servers.GoogleConfig()
|
||||
}
|
||||
|
||||
// AzureServerConfig returns an Azure AD-like server configuration
|
||||
func AzureServerConfig() *servers.OIDCServerConfig {
|
||||
return servers.AzureConfig()
|
||||
}
|
||||
|
||||
// Auth0ServerConfig returns an Auth0-like server configuration
|
||||
func Auth0ServerConfig() *servers.OIDCServerConfig {
|
||||
return servers.Auth0Config()
|
||||
}
|
||||
|
||||
// KeycloakServerConfig returns a Keycloak-like server configuration
|
||||
func KeycloakServerConfig() *servers.OIDCServerConfig {
|
||||
return servers.KeycloakConfig()
|
||||
}
|
||||
|
||||
// SlowServerConfig returns a configuration with delays
|
||||
func SlowServerConfig(delay time.Duration) *servers.OIDCServerConfig {
|
||||
return servers.SlowServerConfig(delay)
|
||||
}
|
||||
|
||||
// RateLimitedServerConfig returns a rate-limited configuration
|
||||
func RateLimitedServerConfig(afterN int) *servers.OIDCServerConfig {
|
||||
return servers.RateLimitedConfig(afterN)
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
package fixtures
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenFixture provides JWT token generation for tests
|
||||
type TokenFixture struct {
|
||||
RSAPrivateKey *rsa.PrivateKey
|
||||
RSAPublicKey *rsa.PublicKey
|
||||
ECPrivateKey *ecdsa.PrivateKey
|
||||
ECPublicKey *ecdsa.PublicKey
|
||||
KeyID string
|
||||
Issuer string
|
||||
Audience string
|
||||
ClockSkew time.Duration
|
||||
}
|
||||
|
||||
// NewTokenFixture creates a new token fixture with generated keys
|
||||
func NewTokenFixture() (*TokenFixture, error) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenFixture{
|
||||
RSAPrivateKey: rsaKey,
|
||||
RSAPublicKey: &rsaKey.PublicKey,
|
||||
ECPrivateKey: ecKey,
|
||||
ECPublicKey: &ecKey.PublicKey,
|
||||
KeyID: "test-key-id",
|
||||
Issuer: "https://test-issuer.com",
|
||||
Audience: "test-client-id",
|
||||
ClockSkew: 2 * time.Minute,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DefaultClaims returns standard JWT claims
|
||||
func (f *TokenFixture) DefaultClaims() map[string]interface{} {
|
||||
now := time.Now()
|
||||
return map[string]interface{}{
|
||||
"iss": f.Issuer,
|
||||
"aud": f.Audience,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"exp": now.Add(1 * time.Hour).Unix(),
|
||||
"iat": now.Add(-f.ClockSkew).Unix(),
|
||||
"nbf": now.Add(-f.ClockSkew).Unix(),
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateJTI(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidToken creates a valid JWT token with optional claim overrides
|
||||
func (f *TokenFixture) ValidToken(claimOverrides map[string]interface{}) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
for k, v := range claimOverrides {
|
||||
claims[k] = v
|
||||
}
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// ExpiredToken creates an expired JWT token
|
||||
func (f *TokenFixture) ExpiredToken() (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["exp"] = time.Now().Add(-1 * time.Hour).Unix()
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// NotYetValidToken creates a token that's not valid yet
|
||||
func (f *TokenFixture) NotYetValidToken() (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["nbf"] = time.Now().Add(1 * time.Hour).Unix()
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithSkew creates a token with a specific time offset
|
||||
func (f *TokenFixture) TokenWithSkew(skew time.Duration) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["exp"] = time.Now().Add(skew).Unix()
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithRoles creates a token with specific roles
|
||||
func (f *TokenFixture) TokenWithRoles(roles []string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["roles"] = roles
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithGroups creates a token with specific groups
|
||||
func (f *TokenFixture) TokenWithGroups(groups []string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["groups"] = groups
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithEmail creates a token with a specific email
|
||||
func (f *TokenFixture) TokenWithEmail(email string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["email"] = email
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithAudience creates a token with a specific audience
|
||||
func (f *TokenFixture) TokenWithAudience(audience string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["aud"] = audience
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithIssuer creates a token with a specific issuer
|
||||
func (f *TokenFixture) TokenWithIssuer(issuer string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
claims["iss"] = issuer
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenMissingClaim creates a token missing specified claims
|
||||
func (f *TokenFixture) TokenMissingClaim(missingClaims ...string) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
for _, claim := range missingClaims {
|
||||
delete(claims, claim)
|
||||
}
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// TokenWithCustomClaims creates a token with custom claims
|
||||
func (f *TokenFixture) TokenWithCustomClaims(customClaims map[string]interface{}) (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
for k, v := range customClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
return f.createJWT(claims, "RS256", f.KeyID)
|
||||
}
|
||||
|
||||
// MalformedToken returns an invalid JWT string
|
||||
func (f *TokenFixture) MalformedToken() string {
|
||||
return "not.a.valid.jwt"
|
||||
}
|
||||
|
||||
// EmptyToken returns an empty string
|
||||
func (f *TokenFixture) EmptyToken() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TokenWithWrongSignature creates a token signed with a different key
|
||||
func (f *TokenFixture) TokenWithWrongSignature() (string, error) {
|
||||
wrongKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := f.DefaultClaims()
|
||||
return createJWTWithKey(claims, "RS256", f.KeyID, wrongKey)
|
||||
}
|
||||
|
||||
// TokenWithWrongAlgorithm creates a token with mismatched algorithm
|
||||
func (f *TokenFixture) TokenWithWrongAlgorithm() (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
// Create token claiming RS256 but we'll return it as-is
|
||||
// This simulates algorithm confusion attacks
|
||||
return f.createJWT(claims, "none", f.KeyID)
|
||||
}
|
||||
|
||||
// ECToken creates a token signed with EC key
|
||||
func (f *TokenFixture) ECToken() (string, error) {
|
||||
claims := f.DefaultClaims()
|
||||
return f.createECJWT(claims, "ES256", f.KeyID)
|
||||
}
|
||||
|
||||
// GetJWKS returns a JWKS containing the test public key
|
||||
func (f *TokenFixture) GetJWKS() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": f.KeyID,
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"n": base64.RawURLEncoding.EncodeToString(f.RSAPublicKey.N.Bytes()),
|
||||
"e": base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(f.RSAPublicKey.E)))),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetJWKSBytes returns JWKS as JSON bytes
|
||||
func (f *TokenFixture) GetJWKSBytes() ([]byte, error) {
|
||||
return json.Marshal(f.GetJWKS())
|
||||
}
|
||||
|
||||
// createJWT creates a JWT with the fixture's RSA key
|
||||
func (f *TokenFixture) createJWT(claims map[string]interface{}, alg, kid string) (string, error) {
|
||||
return createJWTWithKey(claims, alg, kid, f.RSAPrivateKey)
|
||||
}
|
||||
|
||||
// createECJWT creates a JWT with the fixture's EC key
|
||||
func (f *TokenFixture) createECJWT(claims map[string]interface{}, alg, kid string) (string, error) {
|
||||
return createECJWTWithKey(claims, alg, kid, f.ECPrivateKey)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func generateJTI() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b) // #nosec G104 - test fixture, crypto strength not critical
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func bigIntToBytes(i *big.Int) []byte {
|
||||
return i.Bytes()
|
||||
}
|
||||
|
||||
func createJWTWithKey(claims map[string]interface{}, alg, kid string, key *rsa.PrivateKey) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
"alg": alg,
|
||||
"typ": "JWT",
|
||||
"kid": kid,
|
||||
}
|
||||
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsBytes)
|
||||
|
||||
signingInput := headerB64 + "." + claimsB64
|
||||
|
||||
// For "none" algorithm, return without signature
|
||||
if alg == "none" {
|
||||
return signingInput + ".", nil
|
||||
}
|
||||
|
||||
// Sign with RSA-SHA256
|
||||
signature, err := signRS256([]byte(signingInput), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||
return signingInput + "." + signatureB64, nil
|
||||
}
|
||||
|
||||
func createECJWTWithKey(claims map[string]interface{}, alg, kid string, key *ecdsa.PrivateKey) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
"alg": alg,
|
||||
"typ": "JWT",
|
||||
"kid": kid,
|
||||
}
|
||||
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsBytes)
|
||||
|
||||
signingInput := headerB64 + "." + claimsB64
|
||||
|
||||
// Sign with ECDSA-SHA256
|
||||
signature, err := signES256([]byte(signingInput), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||
return signingInput + "." + signatureB64, nil
|
||||
}
|
||||
|
||||
func signRS256(data []byte, key *rsa.PrivateKey) ([]byte, error) {
|
||||
h := hashSHA256(data)
|
||||
return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h)
|
||||
}
|
||||
|
||||
func signES256(data []byte, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
h := hashSHA256(data)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, key, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encode r and s as fixed-size byte arrays
|
||||
curveBits := key.Curve.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
|
||||
rBytes := r.Bytes()
|
||||
sBytes := s.Bytes()
|
||||
|
||||
signature := make([]byte, 2*keyBytes)
|
||||
copy(signature[keyBytes-len(rBytes):keyBytes], rBytes)
|
||||
copy(signature[2*keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func hashSHA256(data []byte) []byte {
|
||||
h := sha256.Sum256(data)
|
||||
return h[:]
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user