mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e64fc7f730 | |||
| 5fcbd54955 | |||
| e70cd1907c | |||
| e45b06c86d | |||
| ae59a5e88a | |||
| 79e9b164f9 | |||
| 93888e56d1 | |||
| eff9bd7bd2 | |||
| bde1db1c3b | |||
| 79d34ea4c9 | |||
| c3f23cb99b | |||
| 3bbc6a1608 |
@@ -0,0 +1,38 @@
|
||||
# Code Owners for traefik-oidc
|
||||
# These owners will be automatically requested for review when someone opens a PR
|
||||
|
||||
# Default owner for everything in the repo
|
||||
* @lukaszraczylo
|
||||
|
||||
# Core authentication and middleware
|
||||
/middleware/ @lukaszraczylo
|
||||
/auth/ @lukaszraczylo
|
||||
/handlers/ @lukaszraczylo
|
||||
|
||||
# OIDC providers
|
||||
/internal/providers/ @lukaszraczylo
|
||||
|
||||
# Session management and security
|
||||
/session/ @lukaszraczylo
|
||||
/internal/security/ @lukaszraczylo
|
||||
/security/ @lukaszraczylo
|
||||
|
||||
# Token management
|
||||
/internal/token/ @lukaszraczylo
|
||||
|
||||
# Configuration
|
||||
/config/ @lukaszraczylo
|
||||
/.traefik.yml @lukaszraczylo
|
||||
|
||||
# GitHub Actions and CI/CD
|
||||
/.github/ @lukaszraczylo
|
||||
/.github/workflows/ @lukaszraczylo
|
||||
/.golangci.yml @lukaszraczylo
|
||||
|
||||
# Documentation
|
||||
/docs/ @lukaszraczylo
|
||||
README.md @lukaszraczylo
|
||||
|
||||
# Dependencies
|
||||
go.mod @lukaszraczylo
|
||||
go.sum @lukaszraczylo
|
||||
@@ -0,0 +1,123 @@
|
||||
## Description
|
||||
|
||||
<!-- Provide a brief description of the changes in this PR -->
|
||||
|
||||
## Type of Change
|
||||
|
||||
<!-- Mark the relevant option with an "x" -->
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] Documentation update
|
||||
- [ ] Performance improvement
|
||||
- [ ] Code refactoring
|
||||
- [ ] Security fix
|
||||
- [ ] Provider-specific fix/enhancement
|
||||
|
||||
## Related Issues
|
||||
|
||||
<!-- Link to related issues using #issue_number -->
|
||||
|
||||
Fixes #
|
||||
Related to #
|
||||
|
||||
## Changes Made
|
||||
|
||||
<!-- List the main changes made in this PR -->
|
||||
|
||||
-
|
||||
-
|
||||
-
|
||||
|
||||
## Provider Impact
|
||||
|
||||
<!-- If this affects specific OIDC providers, list them here -->
|
||||
|
||||
- [ ] Google
|
||||
- [ ] Azure AD
|
||||
- [ ] Auth0
|
||||
- [ ] Okta
|
||||
- [ ] Keycloak
|
||||
- [ ] AWS Cognito
|
||||
- [ ] GitLab
|
||||
- [ ] GitHub
|
||||
- [ ] Generic OIDC
|
||||
- [ ] All providers
|
||||
|
||||
## Testing Performed
|
||||
|
||||
<!-- Describe the tests you ran to verify your changes -->
|
||||
|
||||
- [ ] Unit tests pass locally
|
||||
- [ ] Integration tests pass locally
|
||||
- [ ] Race detector shows no issues
|
||||
- [ ] Memory leak tests pass
|
||||
- [ ] Manual testing performed
|
||||
|
||||
### Test Configuration
|
||||
|
||||
<!-- Provide details about your test configuration if applicable -->
|
||||
|
||||
**Provider tested:**
|
||||
**Go version:**
|
||||
**Traefik version:**
|
||||
|
||||
## Security Considerations
|
||||
|
||||
<!-- Describe any security implications of these changes -->
|
||||
|
||||
- [ ] This PR does not introduce security vulnerabilities
|
||||
- [ ] Security scanning has been performed
|
||||
- [ ] Credentials/secrets are properly handled
|
||||
- [ ] Input validation is implemented
|
||||
|
||||
## Performance Impact
|
||||
|
||||
<!-- Describe any performance implications -->
|
||||
|
||||
- [ ] No performance impact expected
|
||||
- [ ] Performance improved (describe how)
|
||||
- [ ] Performance may be affected (describe why and mitigation)
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
<!-- If this is a breaking change, describe the impact and migration path -->
|
||||
|
||||
**Breaking changes:**
|
||||
|
||||
|
||||
**Migration guide:**
|
||||
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Ensure all items are checked before requesting review -->
|
||||
|
||||
- [ ] My code follows the project's code style
|
||||
- [ ] I have performed a self-review of my code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] I have made corresponding changes to the documentation
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] New and existing unit tests pass locally with my changes
|
||||
- [ ] Any dependent changes have been merged and published
|
||||
|
||||
## Additional Context
|
||||
|
||||
<!-- Add any other context, screenshots, or information about the PR here -->
|
||||
|
||||
## Screenshots (if applicable)
|
||||
|
||||
<!-- Add screenshots to help explain your changes -->
|
||||
|
||||
---
|
||||
|
||||
**For Reviewers:**
|
||||
|
||||
Please verify:
|
||||
- [ ] Code quality and style
|
||||
- [ ] Test coverage is adequate
|
||||
- [ ] Security implications reviewed
|
||||
- [ ] Documentation is updated
|
||||
- [ ] No performance regressions
|
||||
@@ -0,0 +1,52 @@
|
||||
version: 2
|
||||
updates:
|
||||
# Maintain dependencies for GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 5
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "github-actions"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
|
||||
# Maintain Go module dependencies
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 10
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "go"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
# Group patch updates together
|
||||
groups:
|
||||
patch-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "minor"
|
||||
# Ignore certain dependencies if needed
|
||||
ignore:
|
||||
# Example: ignore specific versions
|
||||
# - dependency-name: "github.com/example/package"
|
||||
# versions: ["1.x", "2.x"]
|
||||
@@ -0,0 +1,9 @@
|
||||
# Ensure consistent line endings
|
||||
* text=auto eol=lf
|
||||
|
||||
# GitHub Actions files should use LF
|
||||
*.yml text eol=lf
|
||||
*.yaml text eol=lf
|
||||
|
||||
# Shell scripts should use LF
|
||||
*.sh text eol=lf
|
||||
@@ -0,0 +1,225 @@
|
||||
# GitHub Actions Workflows
|
||||
|
||||
This directory contains CI/CD workflows for the Traefik OIDC middleware.
|
||||
|
||||
## Workflows
|
||||
|
||||
### PR Validation (`pr-validation.yml`)
|
||||
|
||||
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
|
||||
|
||||
**Triggered on:**
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
**Parallel Jobs (20+ concurrent checks):**
|
||||
|
||||
#### Code Quality
|
||||
- **Quick Checks** - Format, go vet, go mod verify
|
||||
- **golangci-lint** - Comprehensive linting
|
||||
- **Staticcheck** - Static analysis
|
||||
|
||||
#### Security
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database check
|
||||
- **CodeQL** - GitHub's code analysis
|
||||
|
||||
#### Testing
|
||||
- **Race Detector** - Concurrent access bug detection
|
||||
- **Coverage** - Test coverage with 75% threshold
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration test suite
|
||||
- **Regression Tests** - Prevent previously fixed bugs
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management validation
|
||||
- **Token Tests** - Token validation scenarios
|
||||
- **CSRF Tests** - CSRF protection validation
|
||||
|
||||
#### Provider Testing (Matrix)
|
||||
Tests run in parallel for each OIDC provider:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Compatibility
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Build Matrix** - linux/darwin × amd64/arm64
|
||||
- **Go Versions** - Go 1.23 and 1.24 compatibility
|
||||
|
||||
#### Final Validation
|
||||
- **All Checks Passed** - Ensures all jobs succeeded
|
||||
|
||||
## Workflow Features
|
||||
|
||||
### 🚀 Parallel Execution
|
||||
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
|
||||
|
||||
### 📊 Coverage Reporting
|
||||
- Automatic PR comments with coverage statistics
|
||||
- Per-package coverage breakdown
|
||||
- 75% coverage threshold enforcement
|
||||
|
||||
### 🔒 Security First
|
||||
- Multiple security scanners (gosec, govulncheck, CodeQL)
|
||||
- SARIF report uploads for GitHub Security tab
|
||||
- Security edge case testing
|
||||
|
||||
### 🎯 Comprehensive Testing
|
||||
- Race condition detection
|
||||
- Memory leak detection
|
||||
- Provider-specific testing
|
||||
- Integration and regression tests
|
||||
|
||||
### 📈 Performance Tracking
|
||||
- Benchmark results stored as artifacts
|
||||
- Performance regression detection
|
||||
|
||||
### ✅ Quality Gates
|
||||
All checks must pass before PR can be merged:
|
||||
- Code formatting and style
|
||||
- Security vulnerabilities
|
||||
- Test coverage threshold
|
||||
- Race conditions
|
||||
- Memory leaks
|
||||
- Build success on all platforms
|
||||
|
||||
## Local Development
|
||||
|
||||
### Run checks locally before pushing:
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Run tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Check coverage
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# Run specific test suites
|
||||
go test -v -run='.*Leak.*' ./... # Memory leak tests
|
||||
go test -v -run='.*Integration.*' ./... # Integration tests
|
||||
go test -v -run='.*Regression.*' ./... # Regression tests
|
||||
|
||||
# Run benchmarks
|
||||
go test -bench=. -benchmem ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
govulncheck ./...
|
||||
```
|
||||
|
||||
### Required Tools
|
||||
|
||||
Install these tools for local development:
|
||||
|
||||
```bash
|
||||
# golangci-lint
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Workflow Fails
|
||||
|
||||
1. **Check job status** - Click on failed job for details
|
||||
2. **Review logs** - Expand failed steps to see error messages
|
||||
3. **Run locally** - Reproduce issue with local commands above
|
||||
4. **Check coverage** - Ensure test coverage meets 75% threshold
|
||||
|
||||
### Coverage Below Threshold
|
||||
|
||||
Add tests to increase coverage:
|
||||
```bash
|
||||
# See which lines aren't covered
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Detected
|
||||
|
||||
Run with race detector locally:
|
||||
```bash
|
||||
go test -race -v ./...
|
||||
```
|
||||
|
||||
### Provider Test Failure
|
||||
|
||||
Test specific provider:
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./internal/providers/...
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
The workflow is optimized for speed:
|
||||
|
||||
- **Parallel execution** - All independent jobs run simultaneously
|
||||
- **Go caching** - Dependencies cached between runs
|
||||
- **Strategic ordering** - Quick checks run first for fast feedback
|
||||
- **Fail-fast disabled** - Continue running all tests even if some fail
|
||||
|
||||
## Workflow Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
|
||||
|
||||
### Status Badges
|
||||
Add to README.md:
|
||||
```markdown
|
||||

|
||||
```
|
||||
|
||||
### Notifications
|
||||
Configure in repository settings:
|
||||
- Settings → Notifications
|
||||
- Choose email or Slack notifications for workflow failures
|
||||
|
||||
## Maintenance
|
||||
|
||||
### Update Go Version
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
go-version: '1.24' # Update this
|
||||
```
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
THRESHOLD=75 # Adjust this value
|
||||
```
|
||||
|
||||
### Add New Provider
|
||||
Add to provider matrix:
|
||||
```yaml
|
||||
matrix:
|
||||
provider:
|
||||
- new_provider # Add here
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
- [golangci-lint Configuration](../.golangci.yml)
|
||||
- [Dependabot Configuration](../dependabot.yml)
|
||||
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
|
||||
@@ -0,0 +1,622 @@
|
||||
name: PR Validation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
checks: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
# Fast feedback - format and basic checks
|
||||
quick-checks:
|
||||
name: Quick Checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
# Exclude vendor directory from format checks
|
||||
UNFORMATTED=$(gofmt -s -l . | grep -v "^vendor/" || true)
|
||||
if [ -n "$UNFORMATTED" ]; then
|
||||
echo "Code is not formatted. Run: gofmt -s -w ."
|
||||
echo "Unformatted files:"
|
||||
echo "$UNFORMATTED"
|
||||
gofmt -s -d $(echo "$UNFORMATTED")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Go mod verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Go mod tidy check
|
||||
run: |
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum
|
||||
|
||||
# Static analysis with golangci-lint (advisory - will not fail the build)
|
||||
golangci-lint:
|
||||
name: golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=10m
|
||||
continue-on-error: true # Allow pipeline to continue even with linting warnings
|
||||
|
||||
# Staticcheck analysis
|
||||
staticcheck:
|
||||
name: Staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install staticcheck
|
||||
run: go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run staticcheck
|
||||
run: staticcheck ./...
|
||||
|
||||
# Security scanning with gosec
|
||||
gosec:
|
||||
name: Gosec Security Scanner
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run Gosec Security Scanner
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -no-fail -fmt sarif -out results.sarif ./... || echo "Gosec completed with warnings"
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: always() && hashFiles('results.sarif') != ''
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
continue-on-error: true
|
||||
|
||||
# Vulnerability scanning
|
||||
govulncheck:
|
||||
name: Vulnerability Scan
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Run govulncheck
|
||||
run: govulncheck ./...
|
||||
|
||||
# CodeQL analysis
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: go
|
||||
continue-on-error: true
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
continue-on-error: true
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
continue-on-error: true
|
||||
|
||||
# Unit tests with race detection
|
||||
test-race:
|
||||
name: Unit Tests (Race Detector)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with race detector
|
||||
run: go test -race -timeout=15m -count=1 -v ./...
|
||||
env:
|
||||
GOMAXPROCS: 4
|
||||
|
||||
# Coverage analysis with threshold check
|
||||
test-coverage:
|
||||
name: Test Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test -coverprofile=coverage.out -covermode=atomic -timeout=15m ./...
|
||||
go tool cover -func=coverage.out -o=coverage.txt
|
||||
|
||||
- name: Calculate coverage
|
||||
id: coverage
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//')
|
||||
echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT
|
||||
echo "Total Coverage: $COVERAGE%"
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
continue-on-error: true
|
||||
|
||||
- name: Comment coverage on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const coverage = '${{ steps.coverage.outputs.coverage }}';
|
||||
const threshold = 70;
|
||||
const coverageNum = parseFloat(coverage);
|
||||
const emoji = coverageNum >= threshold ? '✅' : '⚠️';
|
||||
const status = coverageNum >= threshold ? 'meets' : 'below';
|
||||
|
||||
const body = `## ${emoji} Test Coverage Report
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| **Total Coverage** | ${coverage}% |
|
||||
| **Threshold** | ${threshold}% |
|
||||
| **Status** | ${emoji} Coverage ${status} threshold |`;
|
||||
|
||||
// Find existing comment
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
|
||||
const botComment = comments.find(comment =>
|
||||
comment.user.type === 'Bot' &&
|
||||
comment.body.includes('Test Coverage Report')
|
||||
);
|
||||
|
||||
if (botComment) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: botComment.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
}
|
||||
|
||||
- name: Check coverage threshold
|
||||
run: |
|
||||
COVERAGE=${{ steps.coverage.outputs.coverage }}
|
||||
THRESHOLD=70
|
||||
echo "Coverage: $COVERAGE%"
|
||||
echo "Threshold: $THRESHOLD%"
|
||||
if (( $(echo "$COVERAGE < $THRESHOLD" | bc -l) )); then
|
||||
echo "⚠️ Coverage $COVERAGE% is below threshold $THRESHOLD%"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Coverage $COVERAGE% meets threshold $THRESHOLD%"
|
||||
|
||||
# Memory leak detection
|
||||
test-memory-leaks:
|
||||
name: Memory Leak Detection
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run goroutine leak tests
|
||||
run: |
|
||||
echo "Running goroutine leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Gg]oroutine.*[Ll]eak.*' ./... || echo "No goroutine leak tests found"
|
||||
|
||||
- name: Run memory leak tests
|
||||
run: |
|
||||
echo "Running memory leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Mm]emory.*[Ll]eak.*' ./... || echo "No memory leak tests found"
|
||||
|
||||
- name: Run cleanup tests
|
||||
run: |
|
||||
echo "Running cleanup and resource management tests..."
|
||||
go test -v -timeout=20m -run='.*[Cc]leanup.*' ./... || echo "No cleanup tests found"
|
||||
|
||||
# Integration tests
|
||||
test-integration:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
if [ -d "./integration" ]; then
|
||||
go test -v -timeout=20m ./integration/...
|
||||
else
|
||||
echo "Running integration tests from all packages..."
|
||||
go test -v -timeout=20m -run='.*[Ii]ntegration.*' ./...
|
||||
fi
|
||||
|
||||
# Regression tests
|
||||
test-regression:
|
||||
name: Regression Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
echo "Running regression tests..."
|
||||
go test -v -timeout=20m -run='.*[Rr]egression.*' ./...
|
||||
|
||||
# Provider-specific tests (parallel matrix)
|
||||
test-providers:
|
||||
name: Provider Tests (${{ matrix.provider }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
provider:
|
||||
- google
|
||||
- azure
|
||||
- auth0
|
||||
- okta
|
||||
- keycloak
|
||||
- cognito
|
||||
- gitlab
|
||||
- github
|
||||
- generic
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run ${{ matrix.provider }} provider tests
|
||||
run: |
|
||||
PROVIDER_CAP=$(echo "${{ matrix.provider }}" | sed 's/.*/\u&/')
|
||||
echo "Testing $PROVIDER_CAP provider..."
|
||||
go test -v -timeout=10m -run=".*$PROVIDER_CAP.*" ./internal/providers/... || true
|
||||
go test -v -timeout=10m -run=".*${{ matrix.provider }}.*" ./... || true
|
||||
|
||||
# Benchmark tests with performance tracking
|
||||
benchmark:
|
||||
name: Benchmark Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run benchmarks
|
||||
run: |
|
||||
echo "Running benchmark tests..."
|
||||
go test -bench=. -benchmem -benchtime=1s -run=^$ ./... | tee benchmark.txt
|
||||
|
||||
- name: Upload benchmark results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results
|
||||
path: benchmark.txt
|
||||
retention-days: 30
|
||||
|
||||
- name: Compare benchmarks
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
echo "Benchmark results available in artifacts"
|
||||
echo "To compare with main branch, download previous benchmark results"
|
||||
|
||||
# Build validation across platforms
|
||||
build:
|
||||
name: Build (${{ matrix.os }}/${{ matrix.arch }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [linux, darwin]
|
||||
arch: [amd64, arm64]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Build for ${{ matrix.os }}/${{ matrix.arch }}
|
||||
env:
|
||||
GOOS: ${{ matrix.os }}
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
run: |
|
||||
echo "Building for $GOOS/$GOARCH..."
|
||||
go build -v -ldflags="-s -w" ./...
|
||||
|
||||
# Security-specific edge case tests
|
||||
test-security-edge-cases:
|
||||
name: Security Edge Cases
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run security edge case tests
|
||||
run: |
|
||||
echo "Running security edge case tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ecurity.*' ./...
|
||||
|
||||
# Session management tests
|
||||
test-session:
|
||||
name: Session Management Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run session tests
|
||||
run: |
|
||||
echo "Running session management tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ession.*' ./...
|
||||
|
||||
# Token validation tests
|
||||
test-token:
|
||||
name: Token Validation Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run token validation tests
|
||||
run: |
|
||||
echo "Running token validation tests..."
|
||||
go test -v -timeout=15m -run='.*[Tt]oken.*' ./...
|
||||
|
||||
# CSRF and security tests
|
||||
test-csrf:
|
||||
name: CSRF and Security Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run CSRF tests
|
||||
run: |
|
||||
echo "Running CSRF and security tests..."
|
||||
go test -v -timeout=15m -run='.*[Cc][Ss][Rr][Ff].*' ./...
|
||||
|
||||
# Multi-Go version compatibility
|
||||
test-go-versions:
|
||||
name: Go ${{ matrix.go-version }} Compatibility
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ['1.24']
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Run tests on Go ${{ matrix.go-version }}
|
||||
run: go test -short -timeout=10m ./...
|
||||
|
||||
# Final validation - all checks must pass (golangci-lint is advisory)
|
||||
all-checks-passed:
|
||||
name: ✅ All Checks Passed
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- quick-checks
|
||||
- golangci-lint
|
||||
- staticcheck
|
||||
- gosec
|
||||
- govulncheck
|
||||
- codeql
|
||||
- test-race
|
||||
- test-coverage
|
||||
- test-memory-leaks
|
||||
- test-integration
|
||||
- test-regression
|
||||
- test-providers
|
||||
- benchmark
|
||||
- build
|
||||
- test-security-edge-cases
|
||||
- test-session
|
||||
- test-token
|
||||
- test-csrf
|
||||
- test-go-versions
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
echo "Checking status of all jobs..."
|
||||
|
||||
# Check critical jobs (excluding golangci-lint which is advisory)
|
||||
CRITICAL_FAILURES=false
|
||||
|
||||
if [ "${{ needs.quick-checks.result }}" == "failure" ] || \
|
||||
[ "${{ needs.staticcheck.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-race.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-coverage.result }}" == "failure" ] || \
|
||||
[ "${{ needs.build.result }}" == "failure" ]; then
|
||||
CRITICAL_FAILURES=true
|
||||
fi
|
||||
|
||||
if [ "$CRITICAL_FAILURES" == "true" ]; then
|
||||
echo "❌ Critical checks failed"
|
||||
exit 1
|
||||
elif [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
|
||||
echo "⚠️ Some checks were cancelled"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All critical checks passed successfully!"
|
||||
if [ "${{ needs.golangci-lint.result }}" != "success" ]; then
|
||||
echo "ℹ️ Note: golangci-lint reported issues (advisory only)"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Post summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "# PR Validation Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "## Job Status" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Quick Checks: ${{ needs.quick-checks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Linting (advisory): ${{ needs.golangci-lint.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Static Analysis: ${{ needs.staticcheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Scan (gosec): ${{ needs.gosec.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Vulnerability Scan: ${{ needs.govulncheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CodeQL: ${{ needs.codeql.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Race Detection: ${{ needs.test-race.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Coverage: ${{ needs.test-coverage.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Memory Leaks: ${{ needs.test-memory-leaks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Integration Tests: ${{ needs.test-integration.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Regression Tests: ${{ needs.test-regression.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Provider Tests: ${{ needs.test-providers.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Benchmarks: ${{ needs.benchmark.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Build: ${{ needs.build.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Edge Cases: ${{ needs.test-security-edge-cases.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Session Tests: ${{ needs.test-session.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Token Tests: ${{ needs.test-token.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CSRF Tests: ${{ needs.test-csrf.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Go Version Compatibility: ${{ needs.test-go-versions.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
+192
@@ -0,0 +1,192 @@
|
||||
version: "2"
|
||||
run:
|
||||
go: "1.24"
|
||||
modules-download-mode: readonly
|
||||
tests: true
|
||||
linters:
|
||||
enable:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- misspell
|
||||
- noctx
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
disable:
|
||||
- exhaustive
|
||||
- funlen
|
||||
- gocognit
|
||||
- lll
|
||||
- mnd
|
||||
- testpackage
|
||||
- wsl
|
||||
settings:
|
||||
dupl:
|
||||
threshold: 200 # Allow intentional duplication in provider patterns and token management
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
|
||||
exclude-functions:
|
||||
- (io.Closer).Close
|
||||
- (*database/sql.Rows).Close
|
||||
- (*database/sql.Stmt).Close
|
||||
- (io.Writer).Write
|
||||
- (*net/http.ResponseWriter).Write
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintln
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
|
||||
ignore-tests: true
|
||||
gocritic:
|
||||
# Using default enabled checks in v2
|
||||
enabled-checks:
|
||||
- appendCombine
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- commentedOutCode
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- hexLiteral
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nestingReduce
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- stringXbytes
|
||||
- typeAssertChain
|
||||
- typeUnparen
|
||||
- unlabelStmt
|
||||
- yodaStyleExpr
|
||||
gocyclo:
|
||||
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
|
||||
gosec:
|
||||
excludes:
|
||||
- G104
|
||||
- G404
|
||||
severity: medium
|
||||
confidence: medium
|
||||
govet:
|
||||
disable:
|
||||
- fieldalignment
|
||||
- shadow
|
||||
enable-all: true
|
||||
misspell:
|
||||
locale: US
|
||||
ignore-rules:
|
||||
- traefik
|
||||
- oidc
|
||||
- keycloak
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
allow-unused: false
|
||||
prealloc:
|
||||
simple: true
|
||||
range-loops: true
|
||||
for-loops: false
|
||||
revive:
|
||||
rules:
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: context-keys-type
|
||||
- name: dot-imports
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
- name: if-return
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
unparam:
|
||||
check-exported: false
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
|
||||
- -QF1003 # Tagged switch - style preference, may affect Yaegi
|
||||
- -QF1007 # Merge conditional assignment - style preference
|
||||
- -QF1008 # Remove embedded field - may break Yaegi compatibility
|
||||
- -QF1012 # Use fmt.Fprintf - style preference
|
||||
- -ST1003 # Package name format - allowed for test packages
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
- linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- errcheck
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gosec
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
- gocyclo
|
||||
path: test.*\.go
|
||||
- linters:
|
||||
- gocritic
|
||||
- unused
|
||||
path: mocks.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
text: 'G404:'
|
||||
- linters:
|
||||
- all
|
||||
path: vendor/
|
||||
- linters:
|
||||
- goconst
|
||||
path: (.+)_test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session_chunk_manager\.go
|
||||
text: "(extractJWTExpiration|extractJWTIssuedAt)"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
uniq-by-line: true
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
+1100
-65
File diff suppressed because it is too large
Load Diff
+286
@@ -0,0 +1,286 @@
|
||||
# CI/CD Setup Guide
|
||||
|
||||
## 📋 Overview
|
||||
|
||||
This repository now has a comprehensive CI/CD pipeline that runs **20+ parallel checks** on every pull request to ensure code quality, security, and reliability.
|
||||
|
||||
## 🎯 What Was Added
|
||||
|
||||
### GitHub Actions Workflow
|
||||
- **`.github/workflows/pr-validation.yml`** - Main CI/CD pipeline (single file, all parallel)
|
||||
|
||||
### Configuration Files
|
||||
- **`.golangci.yml`** - Linter configuration with 30+ enabled checks
|
||||
- **`.github/dependabot.yml`** - Automated dependency updates
|
||||
- **`.github/CODEOWNERS`** - Automatic PR reviewer assignment
|
||||
- **`.github/PULL_REQUEST_TEMPLATE.md`** - Standardized PR descriptions
|
||||
- **`.github/workflows/README.md`** - Detailed workflow documentation
|
||||
- **`.github/workflows/.gitattributes`** - Consistent line endings
|
||||
|
||||
## ✅ What Gets Tested (All in Parallel)
|
||||
|
||||
### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters including style, complexity, bugs
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning with SARIF reports
|
||||
- **Govulncheck** - Go vulnerability database scanning
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
### Testing (9 test suites)
|
||||
- **Race Detector** - Concurrent access bugs
|
||||
- **Coverage** - 75% threshold with PR comments
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration suite
|
||||
- **Regression Tests** - Prevent old bugs from returning
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management
|
||||
- **Token Tests** - Token validation
|
||||
- **CSRF Tests** - CSRF protection
|
||||
|
||||
### Provider Testing (9 providers in parallel)
|
||||
- Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, GitHub, Generic
|
||||
|
||||
### Performance & Build (3 checks)
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Multi-platform Build** - 4 combinations (linux/darwin × amd64/arm64)
|
||||
- **Go Version Compatibility** - Go 1.23 & 1.24
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Push to GitHub
|
||||
```bash
|
||||
git add .github .golangci.yml CI_SETUP.md
|
||||
git commit -m "Add comprehensive CI/CD pipeline"
|
||||
git push origin main
|
||||
```
|
||||
|
||||
### 2. Create a Test PR
|
||||
```bash
|
||||
# Create a feature branch
|
||||
git checkout -b feature/test-ci
|
||||
echo "# Test" >> test.md
|
||||
git add test.md
|
||||
git commit -m "Test CI pipeline"
|
||||
git push origin feature/test-ci
|
||||
|
||||
# Create PR on GitHub
|
||||
# Watch all 20+ checks run in parallel! ⚡
|
||||
```
|
||||
|
||||
### 3. Monitor Results
|
||||
- Go to Actions tab: `https://github.com/{owner}/{repo}/actions`
|
||||
- Click on latest workflow run
|
||||
- See all parallel checks in action
|
||||
- Review coverage comment on PR
|
||||
|
||||
## 📊 Key Features
|
||||
|
||||
### ⚡ Maximum Speed
|
||||
- **Parallel execution** - All checks run simultaneously
|
||||
- **Smart caching** - Go modules and build cache
|
||||
- **Optimized order** - Quick checks first for fast feedback
|
||||
- **Expected runtime**: 5-10 minutes for full suite
|
||||
|
||||
### 🔒 Security First
|
||||
- **3 security scanners** - gosec, govulncheck, CodeQL
|
||||
- **SARIF integration** - Results in GitHub Security tab
|
||||
- **Dependency scanning** - Automated with Dependabot
|
||||
- **Security edge case tests**
|
||||
|
||||
### 📈 Coverage Tracking
|
||||
- **Automatic PR comments** with coverage stats
|
||||
- **Per-package breakdown** included
|
||||
- **75% threshold** enforced (configurable)
|
||||
- **Codecov integration** ready (optional)
|
||||
|
||||
### 🎨 Developer Experience
|
||||
- **Clear PR template** guides contributors
|
||||
- **Auto code owners** assignment
|
||||
- **Detailed error messages** for failures
|
||||
- **Benchmark tracking** for performance
|
||||
|
||||
## 🛠️ Local Development
|
||||
|
||||
### Install Required Tools
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
### Run Checks Locally
|
||||
```bash
|
||||
# Quick validation (before committing)
|
||||
gofmt -s -w . # Format code
|
||||
go vet ./... # Basic checks
|
||||
go mod tidy # Clean dependencies
|
||||
|
||||
# Linting
|
||||
golangci-lint run # Full lint suite
|
||||
staticcheck ./... # Static analysis
|
||||
|
||||
# Testing
|
||||
go test -race -timeout=15m ./... # Tests with race detector
|
||||
go test -coverprofile=coverage.out ./... # Coverage
|
||||
go tool cover -func=coverage.out # View coverage
|
||||
|
||||
# Security
|
||||
gosec ./... # Security scan
|
||||
govulncheck ./... # Vulnerability check
|
||||
|
||||
# Benchmarks
|
||||
go test -bench=. -benchmem ./... # Performance tests
|
||||
```
|
||||
|
||||
### Pre-commit Checklist
|
||||
```bash
|
||||
# Run this before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "✅ Ready to commit!"
|
||||
```
|
||||
|
||||
## 📝 Configuration
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
THRESHOLD=75 # Change to desired percentage
|
||||
```
|
||||
|
||||
### Modify Linter Rules
|
||||
Edit `.golangci.yml`:
|
||||
```yaml
|
||||
linters:
|
||||
enable:
|
||||
- newlinter # Add new linters here
|
||||
```
|
||||
|
||||
### Update Go Version
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
go-version: '1.24' # Update version
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Coverage Below Threshold
|
||||
```bash
|
||||
# See uncovered lines in browser
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Found
|
||||
```bash
|
||||
# Run specific test with race detector
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
### Linter Errors
|
||||
```bash
|
||||
# See detailed lint errors
|
||||
golangci-lint run -v
|
||||
|
||||
# Auto-fix some issues
|
||||
golangci-lint run --fix
|
||||
```
|
||||
|
||||
### Provider Test Fails
|
||||
```bash
|
||||
# Test specific provider
|
||||
go test -v -run='.*Azure.*' ./internal/providers/
|
||||
```
|
||||
|
||||
## 📈 Metrics & Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
- View all runs: `Actions` tab
|
||||
- Filter by workflow, branch, status
|
||||
- Download logs and artifacts
|
||||
|
||||
### Status Badge
|
||||
Add to README.md:
|
||||
```markdown
|
||||
[](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml)
|
||||
```
|
||||
|
||||
### Notifications
|
||||
- Configure in: Settings → Notifications
|
||||
- Email alerts for workflow failures
|
||||
- Slack/Discord webhooks supported
|
||||
|
||||
## 🔄 Continuous Improvement
|
||||
|
||||
### Dependabot Updates
|
||||
- Automatic weekly dependency checks (Mondays 9 AM)
|
||||
- Security updates prioritized
|
||||
- Groups patch updates together
|
||||
|
||||
### Code Owners
|
||||
- Auto-assigns reviewers based on file paths
|
||||
- Ensures expertise reviews changes
|
||||
- Speeds up PR review process
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Dependabot Config](.github/dependabot.yml)
|
||||
|
||||
## 🎉 Benefits
|
||||
|
||||
### For Contributors
|
||||
- Clear expectations via PR template
|
||||
- Fast feedback (5-10 min)
|
||||
- Comprehensive local tooling
|
||||
- Detailed error messages
|
||||
|
||||
### For Maintainers
|
||||
- Automated code review
|
||||
- Security scanning
|
||||
- Performance tracking
|
||||
- Quality gates enforcement
|
||||
|
||||
### For Users
|
||||
- Higher code quality
|
||||
- Fewer bugs in production
|
||||
- Better security
|
||||
- Consistent performance
|
||||
|
||||
## 🚦 Success Criteria
|
||||
|
||||
All PRs must pass:
|
||||
- ✅ All 20+ parallel checks
|
||||
- ✅ 75% test coverage minimum
|
||||
- ✅ Zero security vulnerabilities
|
||||
- ✅ No race conditions
|
||||
- ✅ No memory leaks
|
||||
- ✅ All providers tested
|
||||
- ✅ Builds on all platforms
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
1. **Run checks locally** before pushing to save CI time
|
||||
2. **Watch for PR comments** - coverage stats posted automatically
|
||||
3. **Check Security tab** for gosec/CodeQL findings
|
||||
4. **Review benchmark results** in artifacts
|
||||
5. **Use draft PRs** for work-in-progress to skip some checks
|
||||
|
||||
---
|
||||
|
||||
**Ready to go!** 🚀 Push your changes and create a PR to see it in action.
|
||||
@@ -0,0 +1,143 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAudienceConfiguration tests the custom audience configuration feature
|
||||
func TestAudienceConfiguration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
clientID string
|
||||
expectedAudience string
|
||||
}{
|
||||
{
|
||||
name: "no custom audience - uses clientID",
|
||||
configAudience: "",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "test-client-id",
|
||||
},
|
||||
{
|
||||
name: "custom audience specified",
|
||||
configAudience: "api://custom-audience",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "api://custom-audience",
|
||||
},
|
||||
{
|
||||
name: "auth0 style custom audience",
|
||||
configAudience: "https://api.example.com",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "https://api.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config with custom audience
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = tt.clientID
|
||||
config.ClientSecret = "test-secret"
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
config.CallbackURL = "/callback"
|
||||
config.Audience = tt.configAudience
|
||||
|
||||
// Create middleware instance
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
traefikOidc, err := NewWithContext(context.Background(), config, next, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware: %v", err)
|
||||
}
|
||||
|
||||
// Verify audience is set correctly
|
||||
if traefikOidc.audience != tt.expectedAudience {
|
||||
t.Errorf("Expected audience %s, got %s", tt.expectedAudience, traefikOidc.audience)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
_ = traefikOidc.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceValidation tests the audience validation in Config.Validate()
|
||||
func TestAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid custom audience URL",
|
||||
audience: "https://api.example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid azure style audience",
|
||||
audience: "api://12345678-1234-1234-1234-123456789012",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty audience is valid (uses clientID)",
|
||||
audience: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "http URL not allowed",
|
||||
audience: "http://api.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience URL must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "wildcard not allowed",
|
||||
audience: "https://*.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "too long audience",
|
||||
audience: "https://" + string(make([]byte, 250)) + ".com",
|
||||
expectError: true,
|
||||
errorContains: "audience must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "invalid characters",
|
||||
audience: "api://test\ninjection",
|
||||
expectError: true,
|
||||
errorContains: "audience contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
config.CallbackURL = "/callback"
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.errorContains, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,931 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
||||
func TestConfigAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty audience is valid for backward compatibility",
|
||||
audience: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTPS URL audience Auth0 format",
|
||||
audience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid identifier audience",
|
||||
audience: "my-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Azure AD Application ID URI format",
|
||||
audience: "api://12345-guid-67890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Auth0 API identifier",
|
||||
audience: "https://my-company.auth0.com/api/v2/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL audience should fail",
|
||||
audience: "http://api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "Audience with wildcard should fail",
|
||||
audience: "https://api.*.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience with single asterisk should fail",
|
||||
audience: "*",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience over 256 characters should fail",
|
||||
audience: strings.Repeat("a", 257),
|
||||
wantErr: true,
|
||||
errContains: "must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with newline should fail",
|
||||
audience: "my-api\ninjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with carriage return should fail",
|
||||
audience: "my-api\rinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with tab should fail",
|
||||
audience: "my-api\tinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Valid audience exactly 256 characters",
|
||||
audience: strings.Repeat("a", 256),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid simple identifier",
|
||||
audience: "my-service-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid URN format",
|
||||
audience: "urn:myservice:api:v1",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
||||
func TestJWTAudienceVerification(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate RSA key for signing JWTs
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
// Create JWK
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
tokenAudience interface{}
|
||||
wantErr bool
|
||||
errContains string
|
||||
skipReplayCheck bool
|
||||
}{
|
||||
{
|
||||
name: "JWT with string aud matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with string aud NOT matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://wrong-api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud NOT containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with clientID as aud when no custom audience configured",
|
||||
configAudience: "",
|
||||
tokenAudience: "test-client-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with empty string aud",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Azure AD Application ID URI format",
|
||||
configAudience: "api://12345-app-id",
|
||||
tokenAudience: "api://12345-app-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Auth0 custom API audience",
|
||||
configAudience: "https://mycompany.com/api",
|
||||
tokenAudience: "https://mycompany.com/api",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Token confusion attack - audience for different service",
|
||||
configAudience: "https://service-a.example.com",
|
||||
tokenAudience: "https://service-b.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Set up the token verifier and JWT verifier
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Determine the expected audience for validation
|
||||
expectedAudience := tt.configAudience
|
||||
if expectedAudience == "" {
|
||||
expectedAudience = tOidc.clientID
|
||||
}
|
||||
|
||||
// Set the audience field on the tOidc instance
|
||||
tOidc.audience = expectedAudience
|
||||
|
||||
// Create JWT with specified audience
|
||||
jti := generateRandomString(16)
|
||||
if tt.skipReplayCheck {
|
||||
// Use a unique JTI for each test to avoid replay detection
|
||||
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
||||
}
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": tt.tokenAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
||||
// when the Audience field is not set
|
||||
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Test with no custom audience configured - should use clientID
|
||||
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // Should match clientID
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
err = ts.tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
||||
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Auth0 scenario: custom audience for API access
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://mycompany.auth0.com"
|
||||
config.ClientID = "auth0-client-id"
|
||||
config.ClientSecret = "auth0-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "https://api.mycompany.com" // Custom API audience
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Auth0 config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "auth0-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches configured audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Auth0 token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.ClientID, // Using clientID instead of API audience
|
||||
"scope": "openid profile email", // Mark as access token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err == nil {
|
||||
t.Error("Auth0 access token with wrong audience should have been rejected")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
||||
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Azure AD scenario: Application ID URI format
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
||||
config.ClientID = "azure-client-id"
|
||||
config.ClientSecret = "azure-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Azure AD config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "azure-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches Application ID URI
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
||||
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
// Service A configuration
|
||||
serviceA := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "service-a-client-id",
|
||||
clientSecret: "service-a-secret",
|
||||
audience: "service-a-client-id", // Service A uses its clientID as audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
serviceA.jwtVerifier = serviceA
|
||||
serviceA.tokenVerifier = serviceA
|
||||
|
||||
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
||||
// Create a token intended for service B
|
||||
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": "https://service-b.example.com", // For service B
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "attacker@example.com",
|
||||
"email": "attacker@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service B token: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify the service B token on service A
|
||||
err = serviceA.VerifyToken(serviceBToken)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
||||
case !strings.Contains(err.Error(), "invalid audience"):
|
||||
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
||||
default:
|
||||
t.Logf("Token confusion attack correctly prevented: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
||||
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
}{
|
||||
{
|
||||
name: "Single asterisk",
|
||||
audience: "*",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in URL",
|
||||
audience: "https://*.example.com",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in path",
|
||||
audience: "https://api.example.com/*",
|
||||
},
|
||||
{
|
||||
name: "Multiple wildcards",
|
||||
audience: "https://*.*.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
||||
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
||||
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
||||
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Newline injection",
|
||||
audience: "api.example.com\nmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Carriage return injection",
|
||||
audience: "api.example.com\rmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Tab injection",
|
||||
audience: "api.example.com\tmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection",
|
||||
audience: "api.example.com\x00malicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
||||
func TestAudienceWithReplayProtection(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a token with custom audience and fixed JTI
|
||||
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
||||
customAudience := "https://api.example.com"
|
||||
|
||||
// Set the audience field to match what we expect
|
||||
tOidc.audience = customAudience
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "user@example.com",
|
||||
"jti": fixedJTI,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Fatalf("First verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the JTI was blacklisted
|
||||
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
||||
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
||||
} else {
|
||||
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
||||
func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Create a test next handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Authenticated with custom audience"))
|
||||
})
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
close(tOidc.initComplete)
|
||||
|
||||
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
||||
// Create a valid token with the custom audience
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.company.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "user-123",
|
||||
"email": "user@company.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a request with authenticated session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
|
||||
// Create session with token
|
||||
resp := httptest.NewRecorder()
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
if err := session.Save(req, resp); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies and add them to a new request
|
||||
cookies := resp.Result().Cookies()
|
||||
req = httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
+120
-71
@@ -11,17 +11,24 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AuthHandler provides core authentication functionality for OIDC flows
|
||||
type AuthHandler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
|
||||
type ScopeFilter interface {
|
||||
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
|
||||
}
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -30,29 +37,31 @@ type Logger interface {
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler instance
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
func (h *Handler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
session SessionData, redirectURL string,
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
||||
|
||||
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
const maxRedirects = 5
|
||||
@@ -128,7 +137,7 @@ func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.R
|
||||
// BuildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
func (h *Handler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.clientID)
|
||||
params.Set("response_type", "code")
|
||||
@@ -144,49 +153,19 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
scopes := make([]string, len(h.scopes))
|
||||
copy(scopes, h.scopes)
|
||||
|
||||
if h.isGoogleProv() {
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
// Apply discovery-based scope filtering if available
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
|
||||
}
|
||||
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else if h.isAzureProv() {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
// Apply provider-specific modifications
|
||||
scopes, params = h.applyProviderSpecificConfig(scopes, params)
|
||||
|
||||
hasOfflineAccess := false
|
||||
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
// Final filtering pass to remove anything the provider doesn't support
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
@@ -198,10 +177,80 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
return h.buildURLWithParams(h.authURL, params)
|
||||
}
|
||||
|
||||
// applyProviderSpecificConfig applies provider-specific scope and parameter modifications
|
||||
func (h *Handler) applyProviderSpecificConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
switch {
|
||||
case h.isGoogleProv():
|
||||
return h.applyGoogleConfig(scopes, params)
|
||||
case h.isAzureProv():
|
||||
return h.applyAzureConfig(scopes, params)
|
||||
default:
|
||||
return h.applyStandardProviderConfig(scopes, params)
|
||||
}
|
||||
}
|
||||
|
||||
// applyGoogleConfig applies Google-specific configuration
|
||||
func (h *Handler) applyGoogleConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
// Google: Remove offline_access if present, add access_type=offline
|
||||
filteredScopes := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
return filteredScopes, params
|
||||
}
|
||||
|
||||
// applyAzureConfig applies Azure AD-specific configuration
|
||||
func (h *Handler) applyAzureConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// applyStandardProviderConfig applies configuration for standard OIDC providers
|
||||
func (h *Handler) applyStandardProviderConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// shouldAddOfflineAccess determines if offline_access scope should be added
|
||||
func (h *Handler) shouldAddOfflineAccess(scopes []string) bool {
|
||||
if h.overrideScopes && len(h.scopes) > 0 {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
// It handles both relative and absolute URLs, validates URL security,
|
||||
// and properly encodes query parameters.
|
||||
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
func (h *Handler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if baseURL != "" {
|
||||
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
||||
if err := h.validateURL(baseURL); err != nil {
|
||||
@@ -252,7 +301,7 @@ func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) stri
|
||||
|
||||
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
||||
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
||||
func (h *AuthHandler) validateURL(urlStr string) error {
|
||||
func (h *Handler) validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("empty URL")
|
||||
}
|
||||
@@ -267,7 +316,7 @@ func (h *AuthHandler) validateURL(urlStr string) error {
|
||||
|
||||
// validateParsedURL validates a parsed URL structure for security.
|
||||
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
||||
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
|
||||
func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
allowedSchemes := map[string]bool{
|
||||
"https": true,
|
||||
"http": true,
|
||||
@@ -298,7 +347,7 @@ func (h *AuthHandler) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
func (h *AuthHandler) validateHost(host string) error {
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,562 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_validateURL tests URL validation functionality
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
url: "http://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
url: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "not-a-url",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - javascript",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - data",
|
||||
url: "data:text/html,<script>alert('xss')</script>",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - file",
|
||||
url: "file:///etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - ftp",
|
||||
url: "ftp://example.com/file",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
url: "https://example.com/../../../etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Path traversal in middle",
|
||||
url: "https://example.com/path/../sensitive/file",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Localhost attempt",
|
||||
url: "https://localhost/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 attempt",
|
||||
url: "https://127.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost attempt",
|
||||
url: "https://[::1]/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 attempt",
|
||||
url: "https://0.0.0.0/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 192.168.x.x",
|
||||
url: "https://192.168.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 10.x.x.x",
|
||||
url: "https://10.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 172.16.x.x",
|
||||
url: "https://172.16.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
url: "https://169.254.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
url: "https://224.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Valid public IP",
|
||||
url: "https://8.8.8.8/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid domain with port",
|
||||
url: "https://example.com:8443/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with case variation",
|
||||
url: "https://LOCALHOST/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
url: "https://example.com:notanumber/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid URL format",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateURL(tt.url)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateURL() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost tests host validation specifically
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid hostname",
|
||||
host: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with subdomain",
|
||||
host: "api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with port",
|
||||
host: "example.com:8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty host",
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "localhost",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (case insensitive)",
|
||||
host: "LOCALHOST",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
host: "localhost:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "127.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
host: "127.0.0.1:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
host: "::1",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0",
|
||||
host: "0.0.0.0",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.1.1",
|
||||
host: "192.168.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.0.0.1",
|
||||
host: "10.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 172.16.0.1",
|
||||
host: "172.16.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Public IP 8.8.8.8",
|
||||
host: "8.8.8.8",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
host: "169.254.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
host: "224.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
host: "example.com::",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "Valid international domain",
|
||||
host: "example.org",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid ccTLD",
|
||||
host: "example.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateHost(tt.host)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateHost() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateHost() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams tests URL building with parameters
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
params url.Values
|
||||
expected string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Absolute HTTPS URL",
|
||||
baseURL: "https://provider.com/auth",
|
||||
params: url.Values{
|
||||
"client_id": []string{"test-client"},
|
||||
"response_type": []string{"code"},
|
||||
},
|
||||
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTP URL",
|
||||
baseURL: "http://provider.com/auth",
|
||||
params: url.Values{
|
||||
"state": []string{"test-state"},
|
||||
},
|
||||
expected: "http://provider.com/auth?state=test-state",
|
||||
},
|
||||
{
|
||||
name: "Relative URL resolved against issuer",
|
||||
baseURL: "/oauth2/authorize",
|
||||
params: url.Values{
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
expected: "https://example.com/oauth2/authorize?scope=openid",
|
||||
},
|
||||
{
|
||||
name: "Root relative URL",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{
|
||||
"nonce": []string{"test-nonce"},
|
||||
},
|
||||
expected: "https://example.com/auth?nonce=test-nonce",
|
||||
},
|
||||
{
|
||||
name: "Invalid absolute URL",
|
||||
baseURL: "https://localhost/auth",
|
||||
params: url.Values{},
|
||||
expectEmpty: true, // Should return empty string due to validation failure
|
||||
},
|
||||
{
|
||||
name: "Invalid relative URL when resolved",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{},
|
||||
expected: "", // Should be empty because issuer validation would be tested separately
|
||||
expectEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.buildURLWithParams(tt.baseURL, tt.params)
|
||||
|
||||
if tt.expectEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For relative URLs, we expect them to be resolved against the issuer URL
|
||||
if !strings.HasPrefix(tt.baseURL, "http") {
|
||||
// Verify it starts with the issuer URL
|
||||
if !strings.HasPrefix(result, handler.issuerURL) {
|
||||
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the result to verify parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all expected parameters are present
|
||||
resultParams := parsedURL.Query()
|
||||
for key, expectedValues := range tt.params {
|
||||
actualValues := resultParams[key]
|
||||
if len(actualValues) != len(expectedValues) {
|
||||
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
|
||||
continue
|
||||
}
|
||||
for i, expectedValue := range expectedValues {
|
||||
if actualValues[i] != expectedValue {
|
||||
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
|
||||
"state": []string{"state with spaces and & special chars"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"special": []string{"value+with+plus&ersand=equals"},
|
||||
}
|
||||
|
||||
result := handler.buildURLWithParams("https://provider.com/auth", params)
|
||||
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse result URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify parameters are correctly encoded/decoded
|
||||
resultParams := parsedURL.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"redirect_uri": "https://example.com/callback?test=value&other=data",
|
||||
"state": "state with spaces and & special chars",
|
||||
"scope": "openid profile email",
|
||||
"special": "value+with+plus&ersand=equals",
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedParams {
|
||||
actualValue := resultParams.Get(key)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateParsedURL tests validateParsedURL method
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL with warning",
|
||||
url: "http://example.com/path",
|
||||
wantErr: false, // Should not error but should log warning
|
||||
},
|
||||
{
|
||||
name: "Invalid scheme",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal",
|
||||
url: "https://example.com/path/../../../etc",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Invalid host (private IP)",
|
||||
url: "https://192.168.1.1/path",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(tt.url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse test URL: %v", err)
|
||||
}
|
||||
|
||||
err = handler.validateParsedURL(parsedURL)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateParsedURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateParsedURL() unexpected error = %v", err)
|
||||
}
|
||||
|
||||
// Check for HTTP warning in debug logs
|
||||
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if strings.Contains(msg, "Warning: Using HTTP scheme") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected HTTP scheme warning in debug logs")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file contains tests for Auth0-specific audience validation scenarios.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuth0Scenario1WithCustomAudience tests Auth0 scenario 1:
|
||||
// - Custom audience configured in plugin
|
||||
// - Authorize endpoint called WITH audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, custom_audience]
|
||||
// Expected: Both tokens validate correctly
|
||||
func TestAuth0Scenario1WithCustomAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (OIDC standard)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token always has client_id
|
||||
"nonce": "test-nonce-scenario1", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, custom_audience]
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience, // Custom API audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data", // Access tokens have scope
|
||||
"jti": "access-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed (should validate against client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify access token validates against custom audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token validation failed (should validate against custom audience): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL includes audience parameter (URL-encoded)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if !strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should contain audience parameter when custom audience is configured, got: %s", authURL)
|
||||
}
|
||||
// Verify the audience is properly URL-encoded (contains %3A for :, %2F for /)
|
||||
if !strings.Contains(authURL, "audience=https%3A%2F%2Fmy-api.example.com") {
|
||||
t.Errorf("Auth URL should contain URL-encoded custom audience, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2DefaultAudience tests Auth0 scenario 2:
|
||||
// - No custom audience configured (defaults to client_id)
|
||||
// - Authorize endpoint called WITHOUT audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, default_audience] (no client_id)
|
||||
// Expected: ID token validates, access token falls back to ID token validation
|
||||
func TestAuth0Scenario2DefaultAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// No custom audience - defaults to client_id
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token with aud = client_id
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario2", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, some_default_audience]
|
||||
// This represents Auth0's default audience behavior
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://test-issuer.com/api/v2/", // Default Auth0 Management API
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Access token won't have client_id in aud, so it will fail validation
|
||||
// This is expected for scenario 2 - the session validation relies on ID token
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Logf("Access token validation passed (unexpected but OK if client_id is in aud array)")
|
||||
} else {
|
||||
// Expected failure - access token doesn't have client_id in aud
|
||||
t.Logf("Access token validation failed as expected (aud doesn't contain client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL does NOT include audience parameter (since audience == client_id)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should NOT contain audience parameter when audience equals client_id, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario3OpaqueAccessToken tests Auth0 scenario 3:
|
||||
// - No custom audience configured
|
||||
// - No default audience in Auth0
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: opaque (not JWT)
|
||||
// Expected: ID token validates, opaque access token is accepted
|
||||
func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable opaque tokens for this scenario (Option C requirement)
|
||||
ts.tOidc.allowOpaqueTokens = true
|
||||
|
||||
// No custom audience
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario3", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token (not a JWT - just a random string)
|
||||
opaqueAccessToken := "opaque_access_token_random_string_12345"
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token should fail JWT validation (expected)
|
||||
err = ts.tOidc.VerifyToken(opaqueAccessToken)
|
||||
if err == nil {
|
||||
t.Error("Opaque access token should fail JWT validation")
|
||||
} else {
|
||||
t.Logf("Opaque access token correctly rejected by JWT validator: %v", err)
|
||||
}
|
||||
|
||||
// Test that validateStandardTokens handles opaque tokens correctly
|
||||
// by falling back to ID token validation
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0AudienceArrayValidation tests that audience validation
|
||||
// correctly handles array audiences (common in Auth0)
|
||||
func TestAuth0AudienceArrayValidation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with audience as array containing our custom audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience,
|
||||
"https://another-api.example.com",
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data write:data",
|
||||
"jti": "array-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - custom audience is in the array
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token with audience array should validate when custom audience is present: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0MismatchedAudience tests that tokens with wrong audience fail validation
|
||||
func TestAuth0MismatchedAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with WRONG audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://different-api.example.com", // Wrong audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "wrong-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail validation - audience doesn't match
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Error("Access token with wrong audience should fail validation")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2StrictMode tests strict audience validation mode:
|
||||
// - Scenario 2 (access token with wrong audience) should be REJECTED
|
||||
// - strictAudienceValidation=true prevents fallback to ID token
|
||||
// - This addresses Allan's security concerns about audience bypass
|
||||
func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable strict mode to prevent Scenario 2 bypass (Option C)
|
||||
ts.tOidc.strictAudienceValidation = true
|
||||
|
||||
// Configure custom audience
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (valid)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-strict",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with WRONG audience (doesn't include custom audience)
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://wrong-api.example.com", // Wrong audience - not our custom audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Test session validation with wrong access token and valid ID token
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetIDToken(idToken)
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
if !needsRefresh {
|
||||
t.Errorf("Strict mode: Should signal refresh needed after rejection, got needsRefresh=%v", needsRefresh)
|
||||
}
|
||||
if expired {
|
||||
t.Errorf("Strict mode: Should not mark as expired (should try refresh first), got expired=%v", expired)
|
||||
}
|
||||
|
||||
t.Logf("✓ Strict mode correctly rejected Scenario 2 (access token audience mismatch)")
|
||||
}
|
||||
|
||||
// TestIDTokenAlwaysValidatesAgainstClientID verifies that ID tokens
|
||||
// are ALWAYS validated against client_id, regardless of configured audience
|
||||
func TestIDTokenAlwaysValidatesAgainstClientID(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure a custom audience different from client_id
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (per OIDC spec)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token MUST have client_id
|
||||
"nonce": "test-nonce-123", // ID tokens have nonce for replay protection
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-client-id-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - ID tokens are checked against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token should validate against client_id even when custom audience is configured: %v", err)
|
||||
}
|
||||
|
||||
// Create ID token with WRONG audience (should fail)
|
||||
wrongIDToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": customAudience, // WRONG - should be client_id
|
||||
"nonce": "test-nonce-wrong-456", // ID token has nonce, so it will be detected as ID token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "wrong-id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create wrong ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail - ID tokens must have client_id as audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(wrongIDToken)
|
||||
if err == nil {
|
||||
t.Error("ID token with custom audience (not client_id) should fail validation")
|
||||
}
|
||||
}
|
||||
@@ -1,920 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateLargeRealisticToken creates a realistic JWT token with a large payload
|
||||
// that mimics real-world OAuth tokens but with enough data to test chunking
|
||||
func generateLargeRealisticToken() string {
|
||||
// Create a realistic JWT header
|
||||
header := map[string]interface{}{
|
||||
"alg": "RS256",
|
||||
"typ": "JWT",
|
||||
"kid": "test-key-id",
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
// Create a large but realistic payload with many claims
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://auth.example.com/",
|
||||
"sub": "auth0|507f1f77bcf86cd799439011",
|
||||
"aud": []string{"https://api.example.com", "https://app.example.com"},
|
||||
"iat": 1516239022,
|
||||
"exp": 1516325422,
|
||||
"azp": "my_client_id",
|
||||
"scope": "openid profile email read:users write:users admin",
|
||||
"gty": "client-credentials",
|
||||
}
|
||||
|
||||
// Add many custom claims to make the token large
|
||||
for i := 0; i < 100; i++ {
|
||||
claimName := fmt.Sprintf("custom_claim_%d", i)
|
||||
claimValue := fmt.Sprintf("This is a test value for claim %d with some additional data to make it larger", i)
|
||||
claims[claimName] = claimValue
|
||||
}
|
||||
|
||||
// Add some array claims with multiple values
|
||||
claims["permissions"] = []string{
|
||||
"read:users", "write:users", "delete:users", "create:users",
|
||||
"read:posts", "write:posts", "delete:posts", "create:posts",
|
||||
"admin:all", "super:admin", "system:manage", "audit:view",
|
||||
}
|
||||
|
||||
claims["groups"] = []string{
|
||||
"administrators", "developers", "qa_team", "devops",
|
||||
"product_managers", "support_team", "security_team",
|
||||
}
|
||||
|
||||
payloadJSON, _ := json.Marshal(claims)
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
// Create a mock signature (in real scenario this would be cryptographic)
|
||||
signature := base64.RawURLEncoding.EncodeToString(
|
||||
[]byte("mock_signature_with_some_additional_bytes_for_testing_purposes"))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature)
|
||||
}
|
||||
|
||||
// TestAuth0RedirectLoopFix tests the fixes applied to prevent Auth0 redirect loops
|
||||
// specifically focusing on:
|
||||
// 1. Consistent cookie configuration (Path="/", SameSite=Lax)
|
||||
// 2. CSRF token accessibility during OAuth callbacks
|
||||
// 3. Session cookie persistence across OAuth flow
|
||||
// 4. Redirect loop prevention
|
||||
func TestAuth0RedirectLoopFix(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
t.Run("CookieConfigurationConsistency", func(t *testing.T) {
|
||||
testCookieConfigurationConsistency(t, sm)
|
||||
})
|
||||
|
||||
t.Run("CSRFTokenAccessibility", func(t *testing.T) {
|
||||
testCSRFTokenAccessibility(t, sm)
|
||||
})
|
||||
|
||||
t.Run("SessionPersistenceAcrossOAuth", func(t *testing.T) {
|
||||
testSessionPersistenceAcrossOAuth(t, sm)
|
||||
})
|
||||
|
||||
t.Run("RedirectLoopPrevention", func(t *testing.T) {
|
||||
testRedirectLoopPrevention(t, sm)
|
||||
})
|
||||
|
||||
t.Run("CallbackCSRFValidation", func(t *testing.T) {
|
||||
testCallbackCSRFValidation(t, sm)
|
||||
})
|
||||
|
||||
t.Run("EdgeCases", func(t *testing.T) {
|
||||
testEdgeCases(t, sm)
|
||||
})
|
||||
}
|
||||
|
||||
// testCookieConfigurationConsistency verifies that cookies are configured
|
||||
// consistently with Path="/" and SameSite=Lax regardless of request headers
|
||||
func testCookieConfigurationConsistency(t *testing.T, sm *SessionManager) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expectPath string
|
||||
expectSame http.SameSite
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Standard HTTP request should get consistent cookie config",
|
||||
},
|
||||
{
|
||||
name: "XMLHttpRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "XMLHttpRequest should still use SameSite=Lax (fix for redirect loop)",
|
||||
},
|
||||
{
|
||||
name: "HTTPSRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "HTTPS requests should have consistent cookie config",
|
||||
},
|
||||
{
|
||||
name: "CustomDomainRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "auth.example.com",
|
||||
"X-Forwarded-Host": "auth.example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Custom domain requests should maintain consistent config",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/callback", nil)
|
||||
|
||||
// Set headers
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get session and save it to trigger cookie setting
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set some session data to ensure it gets saved
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetAuthenticated(false)
|
||||
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookie configuration
|
||||
cookies := rw.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("No cookies set in response")
|
||||
}
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != tt.expectPath {
|
||||
t.Errorf("Expected Path=%s, got Path=%s for cookie %s",
|
||||
tt.expectPath, cookie.Path, cookie.Name)
|
||||
}
|
||||
if cookie.SameSite != tt.expectSame {
|
||||
t.Errorf("Expected SameSite=%v, got SameSite=%v for cookie %s",
|
||||
tt.expectSame, cookie.SameSite, cookie.Name)
|
||||
}
|
||||
t.Logf("Cookie %s: Path=%s, SameSite=%v, Secure=%v, HttpOnly=%v",
|
||||
cookie.Name, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly)
|
||||
}
|
||||
}
|
||||
|
||||
session.Clear(req, nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testCSRFTokenAccessibility verifies that CSRF tokens remain accessible
|
||||
// during OAuth callbacks regardless of request type
|
||||
func testCSRFTokenAccessibility(t *testing.T, sm *SessionManager) {
|
||||
csrfToken := uuid.New().String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
description: "Standard OAuth callback should access CSRF token",
|
||||
},
|
||||
{
|
||||
name: "AjaxCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
description: "AJAX OAuth callback should access CSRF token",
|
||||
},
|
||||
{
|
||||
name: "HTTPSCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
description: "HTTPS OAuth callback should access CSRF token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Phase 1: Store CSRF token in session (auth initiation)
|
||||
initReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
for key, value := range tt.headers {
|
||||
initReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
initRw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(initReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetIncomingPath("/protected")
|
||||
|
||||
err = session.Save(initReq, initRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from response to simulate browser behavior
|
||||
storedCookies := initRw.Result().Cookies()
|
||||
|
||||
// Phase 2: OAuth callback with same cookies
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=auth_code", nil)
|
||||
|
||||
for key, value := range tt.headers {
|
||||
callbackReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Add cookies to callback request
|
||||
for _, cookie := range storedCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get session in callback
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Verify CSRF token is accessible
|
||||
retrievedCSRF := callbackSession.GetCSRF()
|
||||
if retrievedCSRF == "" {
|
||||
t.Error("CSRF token not accessible in callback session")
|
||||
}
|
||||
if retrievedCSRF != csrfToken {
|
||||
t.Errorf("CSRF token mismatch: expected %s, got %s", csrfToken, retrievedCSRF)
|
||||
}
|
||||
|
||||
// Verify other session data is accessible
|
||||
if callbackSession.GetNonce() != "test-nonce" {
|
||||
t.Error("Nonce not accessible in callback session")
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Error("Incoming path not accessible in callback session")
|
||||
}
|
||||
|
||||
t.Logf("CSRF token successfully retrieved in %s: %s", tt.name, retrievedCSRF)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testSessionPersistenceAcrossOAuth verifies that session data persists
|
||||
// correctly across the OAuth flow without being lost due to cookie issues
|
||||
func testSessionPersistenceAcrossOAuth(t *testing.T, sm *SessionManager) {
|
||||
// Simulate complete OAuth flow
|
||||
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Phase 1: Initial authentication request
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get initial session: %v", err)
|
||||
}
|
||||
|
||||
csrfToken := uuid.New().String()
|
||||
nonce := "test-nonce-" + uuid.New().String()
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath("/protected")
|
||||
session.SetCodeVerifier("test-code-verifier")
|
||||
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save initial session: %v", err)
|
||||
}
|
||||
|
||||
initialCookies := rw.Result().Cookies()
|
||||
if len(initialCookies) == 0 {
|
||||
t.Fatal("No cookies set in initial response")
|
||||
}
|
||||
|
||||
// Phase 2: OAuth provider redirect (user authenticates)
|
||||
redirectReq := httptest.NewRequest("GET", "https://auth0.example.com/authorize", nil)
|
||||
// Add cookies as browser would
|
||||
for _, cookie := range initialCookies {
|
||||
redirectReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Phase 3: OAuth callback
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=auth_code_12345", nil)
|
||||
callbackReq.Header.Set("Host", "example.com")
|
||||
callbackReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
// Add all cookies from initial response
|
||||
for _, cookie := range initialCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackRw := httptest.NewRecorder()
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Verify all session data persisted
|
||||
if callbackSession.GetCSRF() != csrfToken {
|
||||
t.Errorf("CSRF token not persisted: expected %s, got %s",
|
||||
csrfToken, callbackSession.GetCSRF())
|
||||
}
|
||||
if callbackSession.GetNonce() != nonce {
|
||||
t.Errorf("Nonce not persisted: expected %s, got %s",
|
||||
nonce, callbackSession.GetNonce())
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Errorf("Incoming path not persisted: expected /protected, got %s",
|
||||
callbackSession.GetIncomingPath())
|
||||
}
|
||||
if callbackSession.GetCodeVerifier() != "test-code-verifier" {
|
||||
t.Errorf("Code verifier not persisted: expected test-code-verifier, got %s",
|
||||
callbackSession.GetCodeVerifier())
|
||||
}
|
||||
|
||||
// Simulate successful authentication
|
||||
callbackSession.SetAuthenticated(true)
|
||||
callbackSession.SetEmail("user@example.com")
|
||||
callbackSession.SetAccessToken("access_token_12345")
|
||||
callbackSession.SetRefreshToken("refresh_token_12345")
|
||||
callbackSession.SetIDToken("id_token_12345")
|
||||
|
||||
// Clear OAuth-specific data
|
||||
callbackSession.SetCSRF("")
|
||||
callbackSession.SetNonce("")
|
||||
callbackSession.SetCodeVerifier("")
|
||||
callbackSession.ResetRedirectCount()
|
||||
|
||||
err = callbackSession.Save(callbackReq, callbackRw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save callback session: %v", err)
|
||||
}
|
||||
|
||||
t.Log("OAuth flow simulation completed successfully - session data persisted")
|
||||
}
|
||||
|
||||
// testRedirectLoopPrevention verifies that the redirect loop prevention
|
||||
// mechanisms work correctly
|
||||
func testRedirectLoopPrevention(t *testing.T, sm *SessionManager) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
// Test redirect count tracking
|
||||
initialCount := session.GetRedirectCount()
|
||||
if initialCount != 0 {
|
||||
t.Errorf("Initial redirect count should be 0, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Simulate multiple redirect attempts
|
||||
for i := 1; i <= 6; i++ {
|
||||
session.IncrementRedirectCount()
|
||||
count := session.GetRedirectCount()
|
||||
if count != i {
|
||||
t.Errorf("Expected redirect count %d, got %d", i, count)
|
||||
}
|
||||
|
||||
// Test that redirect loop detection kicks in at 5 redirects
|
||||
if i >= 5 {
|
||||
t.Logf("Redirect count at %d - should trigger loop detection", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Test reset functionality
|
||||
session.ResetRedirectCount()
|
||||
if session.GetRedirectCount() != 0 {
|
||||
t.Errorf("Redirect count should be 0 after reset, got %d", session.GetRedirectCount())
|
||||
}
|
||||
|
||||
t.Log("Redirect loop prevention tests passed")
|
||||
}
|
||||
|
||||
// testCallbackCSRFValidation tests CSRF token validation in OAuth callbacks
|
||||
func testCallbackCSRFValidation(t *testing.T, sm *SessionManager) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storedCSRF string
|
||||
callbackState string
|
||||
shouldSucceed bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidCSRF",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "valid-csrf-token-123",
|
||||
shouldSucceed: true,
|
||||
description: "Valid CSRF token should pass validation",
|
||||
},
|
||||
{
|
||||
name: "InvalidCSRF",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "different-csrf-token-456",
|
||||
shouldSucceed: false,
|
||||
description: "Invalid CSRF token should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyStoredCSRF",
|
||||
storedCSRF: "",
|
||||
callbackState: "some-csrf-token",
|
||||
shouldSucceed: false,
|
||||
description: "Empty stored CSRF should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyCallbackState",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "",
|
||||
shouldSucceed: false,
|
||||
description: "Empty callback state should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup phase - store CSRF token
|
||||
setupReq := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
setupReq.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(setupReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get setup session: %v", err)
|
||||
}
|
||||
|
||||
if tt.storedCSRF != "" {
|
||||
session.SetCSRF(tt.storedCSRF)
|
||||
}
|
||||
|
||||
setupRw := httptest.NewRecorder()
|
||||
err = session.Save(setupReq, setupRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save setup session: %v", err)
|
||||
}
|
||||
|
||||
setupCookies := setupRw.Result().Cookies()
|
||||
|
||||
// Callback phase - validate CSRF
|
||||
callbackURL := "http://example.com/callback"
|
||||
if tt.callbackState != "" {
|
||||
callbackURL += "?state=" + tt.callbackState + "&code=test_code"
|
||||
} else {
|
||||
callbackURL += "?code=test_code"
|
||||
}
|
||||
|
||||
callbackReq := httptest.NewRequest("GET", callbackURL, nil)
|
||||
callbackReq.Header.Set("Host", "example.com")
|
||||
|
||||
// Add cookies
|
||||
for _, cookie := range setupCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Perform CSRF validation
|
||||
storedCSRF := callbackSession.GetCSRF()
|
||||
stateParam := callbackReq.URL.Query().Get("state")
|
||||
|
||||
csrfValid := (storedCSRF != "" && stateParam != "" && storedCSRF == stateParam)
|
||||
|
||||
if tt.shouldSucceed && !csrfValid {
|
||||
t.Errorf("CSRF validation should have succeeded but failed. Stored: '%s', State: '%s'",
|
||||
storedCSRF, stateParam)
|
||||
}
|
||||
if !tt.shouldSucceed && csrfValid {
|
||||
t.Errorf("CSRF validation should have failed but succeeded. Stored: '%s', State: '%s'",
|
||||
storedCSRF, stateParam)
|
||||
}
|
||||
|
||||
t.Logf("CSRF validation test '%s': stored='%s', state='%s', valid=%v",
|
||||
tt.name, storedCSRF, stateParam, csrfValid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testEdgeCases tests various edge cases that could cause redirect loops
|
||||
func testEdgeCases(t *testing.T, sm *SessionManager) {
|
||||
t.Run("MissingHeaders", func(t *testing.T) {
|
||||
// Test with minimal headers
|
||||
req := httptest.NewRequest("GET", "http://localhost/callback", nil)
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session with minimal headers: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
session.SetCSRF("test-csrf")
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session with minimal headers: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies still have consistent configuration
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Cookie path inconsistent with minimal headers: got %s", cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Cookie SameSite inconsistent with minimal headers: got %v", cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DifferentDomains", func(t *testing.T) {
|
||||
domains := []string{"example.com", "auth.example.com", "sub.auth.example.com"}
|
||||
|
||||
for _, domain := range domains {
|
||||
req := httptest.NewRequest("GET", "http://"+domain+"/callback", nil)
|
||||
req.Header.Set("Host", domain)
|
||||
req.Header.Set("X-Forwarded-Host", domain)
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session for domain %s: %v", domain, err)
|
||||
}
|
||||
|
||||
session.SetCSRF("test-csrf-" + domain)
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session for domain %s: %v", domain, err)
|
||||
}
|
||||
|
||||
// Verify consistent cookie configuration across domains
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Domain %s: Cookie path inconsistent: got %s", domain, cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Domain %s: Cookie SameSite inconsistent: got %v", domain, cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.Clear(req, nil)
|
||||
t.Logf("Domain %s: Cookie configuration consistent", domain)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentSessions", func(t *testing.T) {
|
||||
// Test that multiple concurrent sessions don't interfere
|
||||
const numSessions = 5
|
||||
sessions := make([]*SessionData, numSessions)
|
||||
|
||||
for i := 0; i < numSessions; i++ {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session %d: %v", i, err)
|
||||
}
|
||||
sessions[i] = session
|
||||
|
||||
// Set unique data for each session
|
||||
session.SetCSRF("csrf-" + string(rune('A'+i)))
|
||||
session.SetNonce("nonce-" + string(rune('A'+i)))
|
||||
}
|
||||
|
||||
// Verify each session has its own data
|
||||
for i, session := range sessions {
|
||||
expectedCSRF := "csrf-" + string(rune('A'+i))
|
||||
expectedNonce := "nonce-" + string(rune('A'+i))
|
||||
|
||||
if session.GetCSRF() != expectedCSRF {
|
||||
t.Errorf("Session %d CSRF mismatch: expected %s, got %s",
|
||||
i, expectedCSRF, session.GetCSRF())
|
||||
}
|
||||
if session.GetNonce() != expectedNonce {
|
||||
t.Errorf("Session %d nonce mismatch: expected %s, got %s",
|
||||
i, expectedNonce, session.GetNonce())
|
||||
}
|
||||
|
||||
session.Clear(nil, nil)
|
||||
}
|
||||
|
||||
t.Log("Concurrent sessions test passed")
|
||||
})
|
||||
|
||||
t.Run("LargeCookieHandling", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
// Test with large realistic JWT token that might require chunking
|
||||
largeToken := generateLargeRealisticToken()
|
||||
session.SetAccessToken(largeToken)
|
||||
session.SetCSRF("test-csrf")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session with large token: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are still consistent even with chunking
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Large cookie path inconsistent: got %s", cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Large cookie SameSite inconsistent: got %v", cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify token can be retrieved correctly
|
||||
if session.GetAccessToken() != largeToken {
|
||||
t.Error("Large access token not retrieved correctly")
|
||||
}
|
||||
|
||||
t.Log("Large cookie handling test passed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSessionManagerEnhanceSessionSecurity tests the enhanced session security
|
||||
// to ensure SameSite is consistently Lax and not dynamically switched
|
||||
func TestSessionManagerEnhanceSessionSecurity(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expectSame http.SameSite
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Standard request should use SameSite=Lax",
|
||||
},
|
||||
{
|
||||
name: "XMLHttpRequestHeader",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "XMLHttpRequest should still use SameSite=Lax (no dynamic switching)",
|
||||
},
|
||||
{
|
||||
name: "AjaxWithForwardedProto",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "AJAX HTTPS request should use SameSite=Lax (no dynamic switching)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Test the EnhanceSessionSecurity method directly
|
||||
options := &sessions.Options{}
|
||||
enhanced := sm.EnhanceSessionSecurity(options, req)
|
||||
|
||||
if enhanced.SameSite != tt.expectSame {
|
||||
t.Errorf("Expected SameSite=%v, got SameSite=%v for %s",
|
||||
tt.expectSame, enhanced.SameSite, tt.description)
|
||||
}
|
||||
|
||||
// Verify Path is always "/"
|
||||
if enhanced.Path != "/" {
|
||||
t.Errorf("Expected Path='/', got Path='%s' for %s",
|
||||
enhanced.Path, tt.description)
|
||||
}
|
||||
|
||||
// Verify HttpOnly is always true
|
||||
if !enhanced.HttpOnly {
|
||||
t.Errorf("Expected HttpOnly=true, got HttpOnly=false for %s", tt.description)
|
||||
}
|
||||
|
||||
t.Logf("%s: SameSite=%v, Path=%s, HttpOnly=%v, Secure=%v",
|
||||
tt.name, enhanced.SameSite, enhanced.Path, enhanced.HttpOnly, enhanced.Secure)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCallbackHandlerIntegration tests the full callback handler integration
|
||||
// to ensure CSRF tokens work correctly with the fixed cookie configuration
|
||||
func TestCallbackHandlerIntegration(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
// Simulate a complete OAuth flow with various request types
|
||||
scenarios := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "StandardBrowser",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AjaxRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HTTPSProxy",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
// Phase 1: Auth initiation - store CSRF token
|
||||
initReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
for key, value := range scenario.headers {
|
||||
initReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
initRw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(initReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get init session: %v", err)
|
||||
}
|
||||
|
||||
csrfToken := uuid.New().String()
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetIncomingPath("/protected")
|
||||
|
||||
err = session.Save(initReq, initRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save init session: %v", err)
|
||||
}
|
||||
|
||||
initCookies := initRw.Result().Cookies()
|
||||
|
||||
// Phase 2: OAuth callback - validate CSRF token access
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=test_code", nil)
|
||||
|
||||
for key, value := range scenario.headers {
|
||||
callbackReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Add cookies from init phase
|
||||
for _, cookie := range initCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// This is the critical test - CSRF token must be accessible
|
||||
retrievedCSRF := callbackSession.GetCSRF()
|
||||
if retrievedCSRF == "" {
|
||||
t.Errorf("Scenario %s: CSRF token not accessible in callback", scenario.name)
|
||||
}
|
||||
if retrievedCSRF != csrfToken {
|
||||
t.Errorf("Scenario %s: CSRF token mismatch - expected %s, got %s",
|
||||
scenario.name, csrfToken, retrievedCSRF)
|
||||
}
|
||||
|
||||
// Validate state parameter matches CSRF token
|
||||
stateParam := callbackReq.URL.Query().Get("state")
|
||||
if stateParam != csrfToken {
|
||||
t.Errorf("Scenario %s: State parameter mismatch - expected %s, got %s",
|
||||
scenario.name, csrfToken, stateParam)
|
||||
}
|
||||
|
||||
// Simulate successful CSRF validation
|
||||
if retrievedCSRF != "" && retrievedCSRF == stateParam {
|
||||
t.Logf("Scenario %s: CSRF validation successful", scenario.name)
|
||||
} else {
|
||||
t.Errorf("Scenario %s: CSRF validation failed", scenario.name)
|
||||
}
|
||||
|
||||
// Verify other session data persisted
|
||||
if callbackSession.GetNonce() != "test-nonce" {
|
||||
t.Errorf("Scenario %s: Nonce not persisted", scenario.name)
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Errorf("Scenario %s: Incoming path not persisted", scenario.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+336
@@ -0,0 +1,336 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// AUTHENTICATION FLOW
|
||||
// ============================================================================
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
|
||||
const maxRedirects = 5
|
||||
redirectCount := session.GetRedirectCount()
|
||||
if redirectCount >= maxRedirects {
|
||||
t.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
|
||||
session.ResetRedirectCount()
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Too many redirects", http.StatusLoopDetected)
|
||||
return fmt.Errorf("redirect limit exceeded")
|
||||
}
|
||||
|
||||
session.IncrementRedirectCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
// generatePKCEParameters generates PKCE code verifier and challenge if PKCE is enabled
|
||||
func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
if !t.enablePKCE {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := deriveCodeChallenge(codeVerifier)
|
||||
t.logger.Debugf("PKCE enabled, generated code challenge")
|
||||
|
||||
return codeVerifier, codeChallenge, nil
|
||||
}
|
||||
|
||||
// prepareSessionForAuthentication clears existing session data and sets new authentication state
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
// Set new authentication state
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
if t.enablePKCE && codeVerifier != "" {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
session.SetIncomingPath(incomingPath)
|
||||
t.logger.Debugf("Storing incoming path: %s", incomingPath)
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The HTTP request initiating authentication.
|
||||
// - session: The session data to prepare for authentication.
|
||||
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
// Check and handle redirect limits
|
||||
if err := t.validateRedirectCount(session, rw, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE parameters if enabled
|
||||
codeVerifier, codeChallenge, err := t.generatePKCEParameters()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate PKCE parameters: %v", err)
|
||||
http.Error(rw, "Failed to generate PKCE parameters", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear existing session data and set new authentication state
|
||||
t.prepareSessionForAuthentication(session, csrfToken, nonce, codeVerifier, req.URL.RequestURI())
|
||||
|
||||
session.MarkDirty()
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
|
||||
csrfToken, nonce)
|
||||
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// handleCallback processes the OIDC callback after user authentication.
|
||||
// It validates state/CSRF tokens, exchanges authorization code for tokens,
|
||||
// verifies the received tokens, extracts claims, and establishes the session.
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The callback request containing authorization code and state.
|
||||
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Session error during callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error")
|
||||
}
|
||||
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
||||
session.GetAuthenticated(), req.URL.String())
|
||||
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
t.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
} else {
|
||||
t.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
}
|
||||
|
||||
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err = t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.ResetRedirectCount()
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// handleExpiredToken handles requests with expired or invalid tokens.
|
||||
// It clears the session data and initiates a new authentication flow.
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The HTTP request with expired token.
|
||||
// - session: The session data to clear.
|
||||
// - redirectURL: The callback URL to be used in the new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
// Reset redirect count to prevent loops when handling expired tokens
|
||||
session.ResetRedirectCount()
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
|
||||
}
|
||||
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// isUserAuthenticated determines the authentication status and refresh requirements.
|
||||
// It delegates to provider-specific validation methods that handle different token types
|
||||
// and expiration behaviors.
|
||||
// Parameters:
|
||||
// - session: The session data containing authentication tokens.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated (bool): True if the user has valid tokens.
|
||||
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
|
||||
// - expired (bool): True if the session is unauthenticated, the token is missing,
|
||||
// or the token verification failed for reasons other than nearing/actual expiration.
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if t.isAzureProvider() {
|
||||
return t.validateAzureTokens(session)
|
||||
} else if t.isGoogleProvider() {
|
||||
return t.validateGoogleTokens(session)
|
||||
}
|
||||
// Auth0 and other providers can now use standard validation
|
||||
// which handles opaque tokens generically
|
||||
return t.validateStandardTokens(session)
|
||||
}
|
||||
|
||||
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
|
||||
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
||||
xhr := req.Header.Get("X-Requested-With")
|
||||
contentType := req.Header.Get("Content-Type")
|
||||
accept := req.Header.Get("Accept")
|
||||
|
||||
return xhr == "XMLHttpRequest" ||
|
||||
strings.Contains(contentType, "application/json") ||
|
||||
strings.Contains(accept, "application/json")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeneratePKCEParameters tests the generatePKCEParameters method
|
||||
func TestGeneratePKCEParameters(t *testing.T) {
|
||||
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE enabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
|
||||
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
|
||||
|
||||
// Verify the challenge is derived from the verifier
|
||||
expectedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE disabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: false,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
|
||||
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err1)
|
||||
|
||||
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
|
||||
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The challenge should always be derivable from the verifier
|
||||
recalculatedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, challenge, recalculatedChallenge,
|
||||
"challenge should always match the SHA256 hash of verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, _, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// RFC 7636 requires verifier to be 43-128 characters
|
||||
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
|
||||
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
_, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// SHA256 hash base64 encoded should be 43 characters
|
||||
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
|
||||
})
|
||||
}
|
||||
+10
-10
@@ -173,7 +173,7 @@ func (bt *BackgroundTask) run() {
|
||||
|
||||
if bt.logger != nil {
|
||||
if !isTestMode() {
|
||||
bt.logger.Info("Starting background task: %s", bt.name)
|
||||
bt.logger.Debug("Starting background task: %s", bt.name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func (bt *BackgroundTask) run() {
|
||||
case <-bt.stopChan:
|
||||
if bt.logger != nil {
|
||||
if !isTestMode() {
|
||||
bt.logger.Info("Stopping background task: %s (before initial execution)", bt.name)
|
||||
bt.logger.Debug("Stopping background task: %s (before initial execution)", bt.name)
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -201,7 +201,7 @@ func (bt *BackgroundTask) run() {
|
||||
case <-bt.stopChan:
|
||||
if bt.logger != nil {
|
||||
if !isTestMode() {
|
||||
bt.logger.Info("Stopping background task: %s (during periodic execution)", bt.name)
|
||||
bt.logger.Debug("Stopping background task: %s (during periodic execution)", bt.name)
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -211,7 +211,7 @@ func (bt *BackgroundTask) run() {
|
||||
case <-bt.stopChan:
|
||||
if bt.logger != nil {
|
||||
if !isTestMode() {
|
||||
bt.logger.Info("Stopping background task: %s (direct stop signal)", bt.name)
|
||||
bt.logger.Debug("Stopping background task: %s (direct stop signal)", bt.name)
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -315,7 +315,7 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
|
||||
if time.Now().Unix()-lastFailure > int64(cb.timeout.Seconds()) {
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerHalfOpen))
|
||||
if cb.logger != nil {
|
||||
cb.logger.Info("Circuit breaker transitioning to half-open for task: %s", taskName)
|
||||
cb.logger.Debug("Circuit breaker transitioning to half-open for task: %s", taskName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -467,7 +467,7 @@ func (tr *TaskRegistry) RegisterTask(name string, task *BackgroundTask) error {
|
||||
tr.cb.OnTaskSuccess(name)
|
||||
|
||||
if tr.logger != nil {
|
||||
tr.logger.Info("Registered background task: %s", name)
|
||||
tr.logger.Debug("Registered background task: %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -483,7 +483,7 @@ func (tr *TaskRegistry) UnregisterTask(name string) {
|
||||
delete(tr.tasks, name)
|
||||
|
||||
if tr.logger != nil {
|
||||
tr.logger.Info("Unregistered background task: %s", name)
|
||||
tr.logger.Debug("Unregistered background task: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -513,7 +513,7 @@ func (tr *TaskRegistry) StopAllTasks() {
|
||||
for name, task := range tasksCopy {
|
||||
task.Stop()
|
||||
if tr.logger != nil {
|
||||
tr.logger.Info("Stopped background task during shutdown: %s", name)
|
||||
tr.logger.Debug("Stopped background task during shutdown: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,7 +538,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
|
||||
|
||||
// Start the task if not already running
|
||||
if !rm.IsTaskRunning(name) {
|
||||
rm.StartBackgroundTask(name)
|
||||
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
|
||||
}
|
||||
|
||||
// Get the task from resource manager's internal registry
|
||||
@@ -641,7 +641,7 @@ func (mm *TaskMemoryMonitor) Start(interval time.Duration) error {
|
||||
mm.started = true
|
||||
|
||||
if mm.logger != nil && !isTestMode() {
|
||||
mm.logger.Info("Started global task memory monitoring with %v interval", interval)
|
||||
mm.logger.Debug("Started global task memory monitoring with %v interval", interval)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// globalRegistryMutex protects only the global registry operations
|
||||
var globalRegistryMutex sync.Mutex
|
||||
|
||||
// TestTaskCircuitBreakerOnTaskFailure tests the OnTaskFailure method
|
||||
func TestTaskCircuitBreakerOnTaskFailure(t *testing.T) {
|
||||
logger := NewLogger("debug") // Create a real logger
|
||||
cb := NewTaskCircuitBreaker(3, time.Minute, logger)
|
||||
|
||||
// Test failure doesn't trigger open state before threshold
|
||||
cb.OnTaskFailure("test-task", errors.New("test error"))
|
||||
if err := cb.CanCreateTask("test-task"); err != nil {
|
||||
t.Error("Circuit breaker should allow task creation after 1 failure (threshold: 3)")
|
||||
}
|
||||
|
||||
// Test failure count reaches threshold and opens circuit
|
||||
cb.OnTaskFailure("test-task", errors.New("test error 2"))
|
||||
cb.OnTaskFailure("test-task", errors.New("test error 3"))
|
||||
|
||||
if err := cb.CanCreateTask("test-task"); err == nil {
|
||||
t.Error("Circuit breaker should prevent task creation after reaching failure threshold")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResetGlobalTaskRegistry tests the reset functionality
|
||||
func TestResetGlobalTaskRegistry(t *testing.T) {
|
||||
globalRegistryMutex.Lock()
|
||||
defer globalRegistryMutex.Unlock()
|
||||
|
||||
// Get the global registry first
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Create and register a dummy task
|
||||
logger := NewLogger("debug")
|
||||
task := NewBackgroundTask("test-task", time.Second, func() {
|
||||
// Do nothing
|
||||
}, logger)
|
||||
|
||||
registry.RegisterTask("test-task", task)
|
||||
|
||||
// Verify task is registered
|
||||
if registry.GetTaskCount() == 0 {
|
||||
t.Error("Expected task to be registered")
|
||||
}
|
||||
|
||||
// Reset the registry
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
// Get registry again and verify it's empty
|
||||
newRegistry := GetGlobalTaskRegistry()
|
||||
if newRegistry.GetTaskCount() != 0 {
|
||||
t.Error("Expected registry to be empty after reset")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTask tests the GetTask method
|
||||
func TestGetTask(t *testing.T) {
|
||||
globalRegistryMutex.Lock()
|
||||
defer globalRegistryMutex.Unlock()
|
||||
|
||||
// Reset registry to ensure clean state
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Test getting non-existent task
|
||||
task, exists := registry.GetTask("non-existent")
|
||||
if task != nil || exists {
|
||||
t.Error("Expected nil and false for non-existent task")
|
||||
}
|
||||
|
||||
// Create and register a task
|
||||
logger := NewLogger("debug")
|
||||
newTask := NewBackgroundTask("test-task", time.Second, func() {
|
||||
// Do nothing
|
||||
}, logger)
|
||||
|
||||
registry.RegisterTask("test-task", newTask)
|
||||
|
||||
// Test getting existing task
|
||||
retrievedTask, exists := registry.GetTask("test-task")
|
||||
if retrievedTask == nil || !exists {
|
||||
t.Error("Expected to retrieve registered task")
|
||||
return
|
||||
}
|
||||
|
||||
if retrievedTask.name != "test-task" {
|
||||
t.Errorf("Expected task name 'test-task', got '%s'", retrievedTask.name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTaskMemoryMonitor tests the NewTaskMemoryMonitor function
|
||||
func TestNewTaskMemoryMonitor(t *testing.T) {
|
||||
// No mutex needed - this doesn't modify global state
|
||||
logger := NewLogger("debug")
|
||||
registry := GetGlobalTaskRegistry()
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
if monitor == nil {
|
||||
t.Error("Expected NewTaskMemoryMonitor to return non-nil monitor")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCurrentStats tests the GetCurrentStats method
|
||||
func TestGetCurrentStats(t *testing.T) {
|
||||
// Don't hold mutex during background task execution to avoid deadlocks
|
||||
logger := NewLogger("debug")
|
||||
registry := GetGlobalTaskRegistry()
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
// Start the monitor and let it collect at least one statistic
|
||||
err := monitor.Start(50 * time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start monitor: %v", err)
|
||||
}
|
||||
|
||||
// Ensure monitor is stopped even if test fails
|
||||
defer func() {
|
||||
monitor.Stop()
|
||||
// Give extra time for cleanup
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}()
|
||||
|
||||
// Wait a bit for the monitor to collect stats
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
stats, err := monitor.GetCurrentStats()
|
||||
if err != nil {
|
||||
// If no stats are available yet, that's acceptable for this test
|
||||
t.Logf("No memory statistics available yet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// TaskMemoryStats is a struct, not a pointer, so it can't be nil
|
||||
if stats.Timestamp.IsZero() {
|
||||
t.Error("Expected GetCurrentStats to return valid timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStatsHistory tests the GetStatsHistory method
|
||||
func TestGetStatsHistory(t *testing.T) {
|
||||
// No mutex needed - this just creates a monitor and checks its initial state
|
||||
logger := NewLogger("debug")
|
||||
registry := GetGlobalTaskRegistry()
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
history := monitor.GetStatsHistory()
|
||||
if history == nil {
|
||||
t.Error("Expected GetStatsHistory to return non-nil history")
|
||||
}
|
||||
|
||||
// A fresh monitor should have empty history
|
||||
if len(history) != 0 {
|
||||
t.Logf("History length: %d (may be non-empty due to shared global state)", len(history))
|
||||
}
|
||||
}
|
||||
|
||||
// TestForceGC tests the ForceGC method
|
||||
func TestForceGC(t *testing.T) {
|
||||
// No mutex needed - this doesn't modify global state
|
||||
logger := NewLogger("debug")
|
||||
registry := GetGlobalTaskRegistry()
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
// This should not panic and should work
|
||||
monitor.ForceGC()
|
||||
// No specific verification needed, just ensuring it doesn't crash
|
||||
}
|
||||
|
||||
// TestShutdownAllTasks tests the ShutdownAllTasks function
|
||||
func TestShutdownAllTasks(t *testing.T) {
|
||||
// Use a unique task name prefix to avoid conflicts with other tests
|
||||
taskPrefix := "shutdown-test-"
|
||||
|
||||
// Create a temporary clean registry state
|
||||
func() {
|
||||
globalRegistryMutex.Lock()
|
||||
defer globalRegistryMutex.Unlock()
|
||||
ResetGlobalTaskRegistry()
|
||||
}()
|
||||
|
||||
registry := GetGlobalTaskRegistry()
|
||||
logger := NewLogger("debug")
|
||||
|
||||
// Create some test tasks with unique names
|
||||
task1 := NewBackgroundTask(taskPrefix+"task1", time.Millisecond, func() {
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
}, logger)
|
||||
|
||||
task2 := NewBackgroundTask(taskPrefix+"task2", time.Millisecond, func() {
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
}, logger)
|
||||
|
||||
// Register tasks under mutex protection
|
||||
func() {
|
||||
globalRegistryMutex.Lock()
|
||||
defer globalRegistryMutex.Unlock()
|
||||
registry.RegisterTask(taskPrefix+"task1", task1)
|
||||
registry.RegisterTask(taskPrefix+"task2", task2)
|
||||
}()
|
||||
|
||||
// Start the tasks (outside mutex to avoid deadlock)
|
||||
task1.Start()
|
||||
task2.Start()
|
||||
|
||||
// Give tasks time to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Shutdown all tasks
|
||||
ShutdownAllTasks()
|
||||
|
||||
// Give shutdown time to complete
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Note: We can't reliably verify task count due to other tests
|
||||
// Just ensure shutdown doesn't panic
|
||||
}
|
||||
+3
-2
@@ -58,12 +58,13 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60 * time.Second,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
|
||||
logger: mockLogger,
|
||||
httpClient: createDefaultHTTPClient(), // Add HTTP client
|
||||
httpClient: CreateDefaultHTTPClient(), // Add HTTP client
|
||||
jwkCache: &JWKCache{}, // Add JWK cache
|
||||
tokenCache: tokenCache,
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
@@ -78,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
|
||||
|
||||
// Initialize session manager
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
// Mock the JWT verification to avoid JWKS lookup issues
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMemoryMonitorComprehensive tests memory monitor edge cases
|
||||
func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.TriggerGC()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Initially should return None (no stats yet)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Equal(t, MemoryPressureNone, pressure)
|
||||
|
||||
// Collect stats to populate lastStats
|
||||
monitor.GetCurrentStats()
|
||||
|
||||
// Now should return a valid pressure level
|
||||
pressure = monitor.GetMemoryPressure()
|
||||
assert.NotNil(t, pressure)
|
||||
})
|
||||
|
||||
t.Run("StartMonitoring can be called", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Start monitoring should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 100*time.Millisecond)
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
})
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// StopMonitoring should not panic even if not started
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
// Can be called multiple times safely
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
// Get initial instance
|
||||
GetGlobalMemoryMonitor()
|
||||
|
||||
// Reset
|
||||
ResetGlobalMemoryMonitor()
|
||||
|
||||
// Should be able to get a new instance
|
||||
monitor := GetGlobalMemoryMonitor()
|
||||
assert.NotNil(t, monitor)
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
ResetGlobalMemoryMonitor()
|
||||
})
|
||||
|
||||
t.Run("String method returns pressure name", func(t *testing.T) {
|
||||
pressures := []struct {
|
||||
level MemoryPressureLevel
|
||||
name string
|
||||
}{
|
||||
{MemoryPressureNone, "None"},
|
||||
{MemoryPressureLow, "Low"},
|
||||
{MemoryPressureModerate, "Moderate"},
|
||||
{MemoryPressureHigh, "High"},
|
||||
{MemoryPressureCritical, "Critical"},
|
||||
{MemoryPressureLevel(999), "Unknown"},
|
||||
}
|
||||
|
||||
for _, p := range pressures {
|
||||
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.NotZero(t, stats.Timestamp)
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskRegistry tests background task registry edge cases
|
||||
func TestBackgroundTaskRegistry(t *testing.T) {
|
||||
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
|
||||
registry1 := GetGlobalTaskRegistry()
|
||||
registry2 := GetGlobalTaskRegistry()
|
||||
|
||||
assert.Equal(t, registry1, registry2, "should return same instance")
|
||||
})
|
||||
|
||||
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-register-task"
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask(taskName, task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify task was registered
|
||||
_, exists := registry.GetTask(taskName)
|
||||
assert.True(t, exists, "task should be registered")
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-singleton-idempotent"
|
||||
callCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
taskFunc := func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// First creation should succeed
|
||||
task1, err1 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NotNil(t, task1)
|
||||
|
||||
// Second creation should also succeed (idempotent)
|
||||
// Returns same task without error
|
||||
task2, err2 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
|
||||
assert.NotNil(t, task2)
|
||||
|
||||
// Clean up
|
||||
if task1 != nil {
|
||||
task1.Stop()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Initially should be 0 or small number
|
||||
initialCount := registry.GetTaskCount()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"count-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask("count-test-task", task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Count should increase
|
||||
newCount := registry.GetTaskCount()
|
||||
assert.Equal(t, initialCount+1, newCount)
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Create multiple tasks
|
||||
for i := 0; i < 3; i++ {
|
||||
taskName := "multi-task-" + string(rune(i+'0'))
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask(taskName, task)
|
||||
}
|
||||
|
||||
// Verify tasks were created
|
||||
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
|
||||
|
||||
// Stop all tasks
|
||||
registry.StopAllTasks()
|
||||
|
||||
// Verify all tasks are removed
|
||||
taskCount := registry.GetTaskCount()
|
||||
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"reset-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask("reset-test-task", task)
|
||||
|
||||
// Reset
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
// Get new registry
|
||||
newRegistry := GetGlobalTaskRegistry()
|
||||
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskLifecycle tests background task lifecycle
|
||||
func TestBackgroundTaskLifecycle(t *testing.T) {
|
||||
t.Run("Start begins task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
executed := false
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"lifecycle-test",
|
||||
50*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
executed = true
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Wait for execution
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Verify it executed
|
||||
mu.Lock()
|
||||
wasExecuted := executed
|
||||
mu.Unlock()
|
||||
|
||||
assert.True(t, wasExecuted, "task should have executed")
|
||||
})
|
||||
|
||||
t.Run("Stop halts task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"stop-test",
|
||||
30*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Let it run a few times
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Record count
|
||||
mu.Lock()
|
||||
countAfterStop := execCount
|
||||
mu.Unlock()
|
||||
|
||||
// Wait more
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Count should not increase
|
||||
mu.Lock()
|
||||
finalCount := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
|
||||
})
|
||||
|
||||
t.Run("Multiple Start calls are safe", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-start-test",
|
||||
100*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Multiple starts should be safe
|
||||
task.Start()
|
||||
task.Start()
|
||||
task.Start()
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Should have executed, but only one goroutine
|
||||
mu.Lock()
|
||||
count := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
|
||||
})
|
||||
|
||||
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-stop-test",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start and stop
|
||||
task.Start()
|
||||
time.Sleep(GetTestDuration(20 * time.Millisecond))
|
||||
|
||||
// Multiple stops should be safe
|
||||
assert.NotPanics(t, func() {
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryMonitorIntegration tests memory monitor integration
|
||||
func TestMemoryMonitorIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory monitor integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("monitoring updates stats", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
defer monitor.StopMonitoring()
|
||||
|
||||
// Start monitoring
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 50*time.Millisecond)
|
||||
|
||||
// Wait for at least one check
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
|
||||
// Get pressure (should be a valid pressure level)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, pressure, "pressure should be a valid level")
|
||||
|
||||
// Stop monitoring
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("global memory monitor singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
monitor1 := GetGlobalMemoryMonitor()
|
||||
monitor2 := GetGlobalMemoryMonitor()
|
||||
|
||||
assert.Equal(t, monitor1, monitor2, "should return same instance")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryStatsCollection tests memory statistics collection
|
||||
func TestMemoryStatsCollection(t *testing.T) {
|
||||
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.HeapSysBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.False(t, stats.Timestamp.IsZero())
|
||||
})
|
||||
|
||||
t.Run("Stats include memory pressure", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
// Should calculate and include pressure level
|
||||
assert.NotNil(t, stats.MemoryPressure)
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, stats.MemoryPressure)
|
||||
})
|
||||
|
||||
t.Run("TriggerGC reduces memory", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Allocate some memory
|
||||
_ = make([]byte, 1024*1024) // 1MB
|
||||
|
||||
// Get stats before GC
|
||||
beforeStats := monitor.GetCurrentStats()
|
||||
|
||||
// Trigger GC
|
||||
monitor.TriggerGC()
|
||||
|
||||
// Get stats after GC
|
||||
afterStats := monitor.GetCurrentStats()
|
||||
|
||||
// After GC should have different stats
|
||||
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewBoundedCache tests creation of bounded cache
|
||||
func TestNewBoundedCache(t *testing.T) {
|
||||
maxSize := 500
|
||||
cache := NewBoundedCache(maxSize)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify we can use basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultUnifiedCacheConfig tests default configuration
|
||||
func TestDefaultUnifiedCacheConfig(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
|
||||
if config.Type != CacheTypeGeneral {
|
||||
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
|
||||
}
|
||||
|
||||
if config.MaxSize != 500 {
|
||||
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
|
||||
}
|
||||
|
||||
if config.MaxMemoryBytes != 64*1024*1024 {
|
||||
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
|
||||
}
|
||||
|
||||
if config.CleanupInterval != 2*time.Minute {
|
||||
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
t.Error("Expected Logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewUnifiedCache tests unified cache creation
|
||||
func TestNewUnifiedCache(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
if cache.UniversalCache == nil {
|
||||
t.Error("Expected UniversalCache to be set")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
|
||||
func TestUnifiedCache_SetMaxSize(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
// Test setting max size
|
||||
newSize := 1000
|
||||
cache.SetMaxSize(newSize)
|
||||
|
||||
// We can't easily verify the size was set without exposing internal fields,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestNewCacheAdapter tests cache adapter creation
|
||||
func TestNewCacheAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache interface{}
|
||||
expectNil bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UniversalCache",
|
||||
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UniversalCache",
|
||||
},
|
||||
{
|
||||
name: "UnifiedCache",
|
||||
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UnifiedCache",
|
||||
},
|
||||
{
|
||||
name: "Invalid cache type",
|
||||
cache: "not-a-cache",
|
||||
expectNil: true,
|
||||
description: "Should return nil for invalid cache type",
|
||||
},
|
||||
{
|
||||
name: "Nil cache",
|
||||
cache: nil,
|
||||
expectNil: true,
|
||||
description: "Should return nil for nil cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adapter := NewCacheAdapter(tt.cache)
|
||||
|
||||
if tt.expectNil {
|
||||
if adapter != nil {
|
||||
t.Errorf("Expected nil adapter, got %v", adapter)
|
||||
}
|
||||
} else {
|
||||
if adapter == nil {
|
||||
t.Error("Expected non-nil adapter")
|
||||
}
|
||||
// Test basic operations
|
||||
adapter.Set("test", "value", time.Hour)
|
||||
value, found := adapter.Get("test")
|
||||
if !found {
|
||||
t.Error("Expected key to be found")
|
||||
}
|
||||
if value != "value" {
|
||||
t.Errorf("Expected 'value', got %v", value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCache tests optimized cache creation
|
||||
func TestNewOptimizedCache(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLRUStrategy tests LRU strategy creation
|
||||
func TestNewLRUStrategy(t *testing.T) {
|
||||
maxSize := 100
|
||||
strategy := NewLRUStrategy(maxSize)
|
||||
|
||||
if strategy == nil {
|
||||
t.Fatal("Expected strategy to be created, got nil")
|
||||
}
|
||||
|
||||
lruStrategy, ok := strategy.(*LRUStrategy)
|
||||
if !ok {
|
||||
t.Fatal("Expected LRUStrategy type")
|
||||
}
|
||||
|
||||
if lruStrategy.maxSize != maxSize {
|
||||
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
|
||||
}
|
||||
|
||||
if lruStrategy.order == nil {
|
||||
t.Error("Expected order list to be initialized")
|
||||
}
|
||||
|
||||
if lruStrategy.elements == nil {
|
||||
t.Error("Expected elements map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_Name tests strategy name
|
||||
func TestLRUStrategy_Name(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
name := strategy.Name()
|
||||
if name != "LRU" {
|
||||
t.Errorf("Expected 'LRU', got %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_ShouldEvict tests eviction logic
|
||||
func TestLRUStrategy_ShouldEvict(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// LRU strategy always returns false for ShouldEvict
|
||||
result := strategy.ShouldEvict("test-item", time.Now())
|
||||
if result != false {
|
||||
t.Error("Expected ShouldEvict to return false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnAccess tests access callback
|
||||
func TestLRUStrategy_OnAccess(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnAccess should not panic
|
||||
strategy.OnAccess("test-key", "test-value")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnRemove tests removal callback
|
||||
func TestLRUStrategy_OnRemove(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnRemove should not panic
|
||||
strategy.OnRemove("test-key")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_EstimateSize tests size estimation
|
||||
func TestLRUStrategy_EstimateSize(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
size := strategy.EstimateSize("test-item")
|
||||
if size != 64 {
|
||||
t.Errorf("Expected size 64, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
|
||||
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
key, found := strategy.GetEvictionCandidate()
|
||||
if found {
|
||||
t.Error("Expected no eviction candidate to be found")
|
||||
}
|
||||
if key != "" {
|
||||
t.Errorf("Expected empty key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
|
||||
func TestNewOptimizedCacheWithConfig(t *testing.T) {
|
||||
config := UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 128 * 1024 * 1024,
|
||||
EnableMetrics: true,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
|
||||
cache := NewOptimizedCacheWithConfig(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewFixedMetadataCache tests fixed metadata cache creation
|
||||
func TestNewFixedMetadataCache(t *testing.T) {
|
||||
cache := NewFixedMetadataCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with proper metadata operations
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/auth",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
}
|
||||
|
||||
err := cache.Set("test-provider", metadata, time.Hour)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error setting metadata: %v", err)
|
||||
}
|
||||
|
||||
// Test that the cache was created (basic verification)
|
||||
// Note: We can't easily test Get without more complex setup
|
||||
}
|
||||
|
||||
// TestNewDoublyLinkedList tests doubly linked list creation
|
||||
func TestNewDoublyLinkedList(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
if list == nil {
|
||||
t.Fatal("Expected list to be created, got nil")
|
||||
}
|
||||
|
||||
// Test it's a proper list structure
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected empty list initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoublyLinkedList_PopFront tests front element removal
|
||||
func TestDoublyLinkedList_PopFront(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
// Test popping from empty list
|
||||
element := list.PopFront()
|
||||
if element != nil {
|
||||
t.Error("Expected nil when popping from empty list")
|
||||
}
|
||||
|
||||
// Add an element and test popping
|
||||
added := list.PushBack("test-value")
|
||||
if added == nil {
|
||||
t.Fatal("Expected element to be added")
|
||||
}
|
||||
|
||||
popped := list.PopFront()
|
||||
if popped == nil {
|
||||
t.Error("Expected element to be popped")
|
||||
}
|
||||
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected list to be empty after popping")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
+46
-3
@@ -21,10 +21,37 @@ var (
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
if config != nil {
|
||||
logger = NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
redisConfig = config.Redis
|
||||
}
|
||||
|
||||
globalCacheManagerInstance = &CacheManager{
|
||||
manager: GetUniversalCacheManager(nil),
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
return globalCacheManagerInstance
|
||||
@@ -61,6 +88,22 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
return &JWKCache{cache: cm.manager.GetJWKCache()}
|
||||
}
|
||||
|
||||
// GetSharedIntrospectionCache returns the shared token introspection cache
|
||||
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
// for caching token type detection results to improve performance
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
@@ -83,7 +126,7 @@ type CacheInterfaceWrapper struct {
|
||||
|
||||
// Set stores a value
|
||||
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
|
||||
c.cache.Set(key, value, ttl)
|
||||
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves a value
|
||||
@@ -110,7 +153,7 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.cache != nil {
|
||||
c.cache.Close()
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to ensure we have a working cache manager for tests
|
||||
func getTestCacheManager(t *testing.T) *CacheManager {
|
||||
cm := GetGlobalCacheManager(&sync.WaitGroup{})
|
||||
if cm == nil {
|
||||
t.Fatal("Failed to get cache manager")
|
||||
}
|
||||
if cm.manager == nil {
|
||||
t.Fatal("Cache manager has nil internal manager")
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
// TestCacheManager_Close tests cache manager close functionality
|
||||
func TestCacheManager_Close(t *testing.T) {
|
||||
// Get a fresh cache manager
|
||||
wg := &sync.WaitGroup{}
|
||||
cm := GetGlobalCacheManager(wg)
|
||||
|
||||
if cm == nil {
|
||||
t.Fatal("Expected cache manager to be created")
|
||||
}
|
||||
|
||||
// Test closing the cache manager
|
||||
err := cm.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error closing cache manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupGlobalCacheManager tests global cleanup
|
||||
func TestCleanupGlobalCacheManager(t *testing.T) {
|
||||
// Test cleanup when no instance exists (should not error)
|
||||
originalInstance := globalCacheManagerInstance
|
||||
globalCacheManagerInstance = nil
|
||||
err := CleanupGlobalCacheManager()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
|
||||
}
|
||||
|
||||
// Restore original instance
|
||||
globalCacheManagerInstance = originalInstance
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Delete tests delete functionality
|
||||
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Verify it exists
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Fatal("Expected key to be found after setting")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
cache.Delete("test-key")
|
||||
|
||||
// Verify it's gone
|
||||
_, found = cache.Get("test-key")
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Size tests size functionality
|
||||
func TestCacheInterfaceWrapper_Size(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Clear cache first
|
||||
cache.Clear()
|
||||
|
||||
// Check initial size
|
||||
initialSize := cache.Size()
|
||||
if initialSize != 0 {
|
||||
t.Errorf("Expected initial size 0, got %d", initialSize)
|
||||
}
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Check size increased
|
||||
newSize := cache.Size()
|
||||
if newSize != 2 {
|
||||
t.Errorf("Expected size 2, got %d", newSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Clear tests clear functionality
|
||||
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Verify items exist
|
||||
size := cache.Size()
|
||||
if size != 2 {
|
||||
t.Errorf("Expected 2 items before clear, got %d", size)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
cache.Clear()
|
||||
|
||||
// Verify cache is empty
|
||||
size = cache.Size()
|
||||
if size != 0 {
|
||||
t.Errorf("Expected 0 items after clear, got %d", size)
|
||||
}
|
||||
|
||||
// Verify specific items are gone
|
||||
_, found := cache.Get("key1")
|
||||
if found {
|
||||
t.Error("Expected key1 to be cleared")
|
||||
}
|
||||
|
||||
_, found = cache.Get("key2")
|
||||
if found {
|
||||
t.Error("Expected key2 to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
|
||||
func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test close - should not panic
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
wrapper.Close() // Should not panic
|
||||
|
||||
// Test close with nil cache
|
||||
nilWrapper := &CacheInterfaceWrapper{cache: nil}
|
||||
nilWrapper.Close() // Should not panic
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_GetStats tests stats functionality
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := wrapper.GetStats()
|
||||
if stats == nil {
|
||||
t.Error("Expected non-nil stats")
|
||||
}
|
||||
|
||||
// Stats should be accessible (len() never returns negative values)
|
||||
// Just verify it's accessible by checking it's not nil (already done above)
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
|
||||
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item that will expire quickly
|
||||
cache.Set("expire-key", "expire-value", time.Millisecond)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Trigger cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Item should be cleaned up
|
||||
_, found := cache.Get("expire-key")
|
||||
if found {
|
||||
t.Error("Expected expired key to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
|
||||
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test setting max size (should not panic)
|
||||
cache.SetMaxSize(1000)
|
||||
|
||||
// We can't easily verify the size was set without exposing internals,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestGetSharedCaches tests getting shared cache instances
|
||||
func TestGetSharedCaches(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Test getting shared token blacklist
|
||||
blacklist := cm.GetSharedTokenBlacklist()
|
||||
if blacklist == nil {
|
||||
t.Error("Expected non-nil token blacklist")
|
||||
}
|
||||
|
||||
// Test getting shared token cache
|
||||
tokenCache := cm.GetSharedTokenCache()
|
||||
if tokenCache == nil {
|
||||
t.Error("Expected non-nil token cache")
|
||||
}
|
||||
|
||||
// Test getting shared metadata cache
|
||||
metadataCache := cm.GetSharedMetadataCache()
|
||||
if metadataCache == nil {
|
||||
t.Error("Expected non-nil metadata cache")
|
||||
}
|
||||
|
||||
// Test getting shared JWK cache
|
||||
jwkCache := cm.GetSharedJWKCache()
|
||||
if jwkCache == nil {
|
||||
t.Error("Expected non-nil JWK cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentCacheAccess tests thread safety
|
||||
func TestConcurrentCacheAccess(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 10
|
||||
|
||||
// Concurrent operations
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", id, j)
|
||||
value := fmt.Sprintf("value-%d-%d", id, j)
|
||||
|
||||
cache.Set(key, value, time.Hour)
|
||||
|
||||
retrieved, found := cache.Get(key)
|
||||
if found && retrieved != value {
|
||||
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
|
||||
}
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Pre-populate cache
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
// Package config provides backward compatibility for legacy configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// LegacyAdapter provides backward compatibility for old Config struct
|
||||
type LegacyAdapter struct {
|
||||
unified *UnifiedConfig
|
||||
adapter *compat.ConfigAdapter
|
||||
}
|
||||
|
||||
// NewLegacyAdapter creates a new legacy adapter from unified config
|
||||
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
|
||||
adapter := compat.NewConfigAdapter(unified)
|
||||
|
||||
// Register getters for commonly used fields
|
||||
adapter.RegisterGetter("ProviderURL", func() interface{} {
|
||||
return unified.Provider.IssuerURL
|
||||
})
|
||||
adapter.RegisterGetter("ClientID", func() interface{} {
|
||||
return unified.Provider.ClientID
|
||||
})
|
||||
adapter.RegisterGetter("ClientSecret", func() interface{} {
|
||||
return unified.Provider.ClientSecret
|
||||
})
|
||||
adapter.RegisterGetter("CallbackURL", func() interface{} {
|
||||
return unified.Provider.RedirectURL
|
||||
})
|
||||
adapter.RegisterGetter("LogoutURL", func() interface{} {
|
||||
return unified.Provider.LogoutURL
|
||||
})
|
||||
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
|
||||
return unified.Provider.PostLogoutRedirectURI
|
||||
})
|
||||
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
|
||||
return unified.Session.EncryptionKey
|
||||
})
|
||||
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
|
||||
return unified.Security.ForceHTTPS
|
||||
})
|
||||
adapter.RegisterGetter("LogLevel", func() interface{} {
|
||||
return unified.Logging.Level
|
||||
})
|
||||
adapter.RegisterGetter("Scopes", func() interface{} {
|
||||
return unified.Provider.Scopes
|
||||
})
|
||||
adapter.RegisterGetter("OverrideScopes", func() interface{} {
|
||||
return unified.Provider.OverrideScopes
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUsers", func() interface{} {
|
||||
return unified.Security.AllowedUsers
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
|
||||
return unified.Security.AllowedUserDomains
|
||||
})
|
||||
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
|
||||
return unified.Security.AllowedRolesAndGroups
|
||||
})
|
||||
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
|
||||
return unified.Security.ExcludedURLs
|
||||
})
|
||||
adapter.RegisterGetter("EnablePKCE", func() interface{} {
|
||||
return unified.Security.EnablePKCE
|
||||
})
|
||||
adapter.RegisterGetter("RateLimit", func() interface{} {
|
||||
return unified.RateLimit.RequestsPerSecond
|
||||
})
|
||||
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
|
||||
return int(unified.Token.RefreshGracePeriod.Seconds())
|
||||
})
|
||||
adapter.RegisterGetter("CookieDomain", func() interface{} {
|
||||
return unified.Session.Domain
|
||||
})
|
||||
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
|
||||
return unified.Security.Headers
|
||||
})
|
||||
|
||||
return &LegacyAdapter{
|
||||
unified: unified,
|
||||
adapter: adapter,
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig converts unified config to old Config struct format
|
||||
func (la *LegacyAdapter) ToOldConfig() *Config {
|
||||
// Use feature flags to determine behavior
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Return existing Config if unified config not enabled
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
ProviderURL: la.unified.Provider.IssuerURL,
|
||||
ClientID: la.unified.Provider.ClientID,
|
||||
ClientSecret: la.unified.Provider.ClientSecret,
|
||||
CallbackURL: la.unified.Provider.RedirectURL,
|
||||
LogoutURL: la.unified.Provider.LogoutURL,
|
||||
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
|
||||
SessionEncryptionKey: la.unified.Session.EncryptionKey,
|
||||
ForceHTTPS: la.unified.Security.ForceHTTPS,
|
||||
LogLevel: la.unified.Logging.Level,
|
||||
Scopes: la.unified.Provider.Scopes,
|
||||
OverrideScopes: la.unified.Provider.OverrideScopes,
|
||||
AllowedUsers: la.unified.Security.AllowedUsers,
|
||||
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
|
||||
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
|
||||
ExcludedURLs: la.unified.Security.ExcludedURLs,
|
||||
EnablePKCE: la.unified.Security.EnablePKCE,
|
||||
RateLimit: la.unified.RateLimit.RequestsPerSecond,
|
||||
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
|
||||
Headers: la.convertHeaders(),
|
||||
CookieDomain: la.unified.Session.Domain,
|
||||
SecurityHeaders: la.unified.Security.Headers,
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// convertHeaders converts unified header config to old format
|
||||
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
|
||||
headers := make([]HeaderConfig, 0)
|
||||
|
||||
for name, value := range la.unified.Middleware.CustomHeaders {
|
||||
headers = append(headers, HeaderConfig{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// FromOldConfig creates unified config from old Config struct
|
||||
func FromOldConfig(old *Config) *UnifiedConfig {
|
||||
unified := NewUnifiedConfig()
|
||||
|
||||
// Map provider settings
|
||||
unified.Provider.IssuerURL = old.ProviderURL
|
||||
unified.Provider.ClientID = old.ClientID
|
||||
unified.Provider.ClientSecret = old.ClientSecret
|
||||
unified.Provider.RedirectURL = old.CallbackURL
|
||||
unified.Provider.LogoutURL = old.LogoutURL
|
||||
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
|
||||
unified.Provider.Scopes = old.Scopes
|
||||
unified.Provider.OverrideScopes = old.OverrideScopes
|
||||
|
||||
// Map session settings
|
||||
unified.Session.EncryptionKey = old.SessionEncryptionKey
|
||||
unified.Session.Domain = old.CookieDomain
|
||||
|
||||
// Map security settings
|
||||
unified.Security.ForceHTTPS = old.ForceHTTPS
|
||||
unified.Security.EnablePKCE = old.EnablePKCE
|
||||
unified.Security.AllowedUsers = old.AllowedUsers
|
||||
unified.Security.AllowedUserDomains = old.AllowedUserDomains
|
||||
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
|
||||
unified.Security.ExcludedURLs = old.ExcludedURLs
|
||||
unified.Security.Headers = old.SecurityHeaders
|
||||
|
||||
// Map rate limiting
|
||||
unified.RateLimit.RequestsPerSecond = old.RateLimit
|
||||
unified.RateLimit.Enabled = old.RateLimit > 0
|
||||
|
||||
// Map token settings
|
||||
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
|
||||
|
||||
// Map logging
|
||||
unified.Logging.Level = old.LogLevel
|
||||
|
||||
// Map custom headers
|
||||
if len(old.Headers) > 0 {
|
||||
unified.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, header := range old.Headers {
|
||||
unified.Middleware.CustomHeaders[header.Name] = header.Value
|
||||
}
|
||||
}
|
||||
|
||||
// Store original config in legacy field for reference
|
||||
unified.Legacy["original"] = old
|
||||
|
||||
return unified
|
||||
}
|
||||
|
||||
// timeSecondsToDuration converts seconds to time.Duration
|
||||
func timeSecondsToDuration(seconds int) time.Duration {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// GetConfigInterface returns appropriate config based on feature flag
|
||||
func GetConfigInterface() interface{} {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
// ValidateConfig validates config based on feature flag
|
||||
func ValidateConfig(cfg interface{}) error {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if unified, ok := cfg.(*UnifiedConfig); ok {
|
||||
return unified.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to old validation if available
|
||||
if old, ok := cfg.(*Config); ok {
|
||||
return old.Validate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add Validate method to old Config for compatibility
|
||||
func (c *Config) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Basic validation for old config
|
||||
if c.ProviderURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ProviderURL",
|
||||
Message: "provider URL is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientSecret == "" && !c.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE)",
|
||||
})
|
||||
}
|
||||
|
||||
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "SessionEncryptionKey",
|
||||
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
|
||||
Value: len(c.SessionEncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// NewLegacyAdapter Tests
|
||||
func TestNewLegacyAdapter(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "test-client"
|
||||
unified.Provider.ClientSecret = "test-secret"
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("Expected NewLegacyAdapter to return non-nil")
|
||||
}
|
||||
|
||||
if adapter.unified != unified {
|
||||
t.Error("Expected adapter to reference the unified config")
|
||||
}
|
||||
|
||||
if adapter.adapter == nil {
|
||||
t.Error("Expected internal adapter to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig Tests
|
||||
func TestLegacyAdapter_ToOldConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://issuer.example.com"
|
||||
unified.Provider.ClientID = "client-123"
|
||||
unified.Provider.ClientSecret = "secret-456"
|
||||
unified.Provider.RedirectURL = "https://app.example.com/callback"
|
||||
unified.Provider.LogoutURL = "/logout"
|
||||
unified.Provider.PostLogoutRedirectURI = "https://app.example.com"
|
||||
unified.Provider.Scopes = []string{"openid", "profile"}
|
||||
unified.Provider.OverrideScopes = true
|
||||
unified.Session.EncryptionKey = "test-encryption-key-32-chars!!"
|
||||
unified.Session.Domain = "example.com"
|
||||
unified.Security.ForceHTTPS = true
|
||||
unified.Security.EnablePKCE = true
|
||||
unified.Security.AllowedUsers = []string{"user@example.com"}
|
||||
unified.Security.AllowedUserDomains = []string{"example.com"}
|
||||
unified.Security.AllowedRolesAndGroups = []string{"admin"}
|
||||
unified.Security.ExcludedURLs = []string{"/health"}
|
||||
unified.RateLimit.RequestsPerSecond = 100
|
||||
unified.Logging.Level = "debug"
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Header-1": "value1",
|
||||
"X-Header-2": "value2",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
oldConfig := adapter.ToOldConfig()
|
||||
|
||||
if oldConfig == nil {
|
||||
t.Fatal("Expected ToOldConfig to return non-nil")
|
||||
}
|
||||
|
||||
// ToOldConfig behavior depends on feature flag
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// When feature is disabled, returns default config
|
||||
if oldConfig.ProviderURL == "" {
|
||||
t.Log("Feature flag disabled - ToOldConfig returns default config")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// When feature is enabled, verify all fields were correctly mapped
|
||||
if oldConfig.ProviderURL != unified.Provider.IssuerURL {
|
||||
t.Errorf("Expected ProviderURL '%s', got '%s'", unified.Provider.IssuerURL, oldConfig.ProviderURL)
|
||||
}
|
||||
|
||||
if oldConfig.ClientID != unified.Provider.ClientID {
|
||||
t.Errorf("Expected ClientID '%s', got '%s'", unified.Provider.ClientID, oldConfig.ClientID)
|
||||
}
|
||||
|
||||
if oldConfig.ClientSecret != unified.Provider.ClientSecret {
|
||||
t.Errorf("Expected ClientSecret '%s', got '%s'", unified.Provider.ClientSecret, oldConfig.ClientSecret)
|
||||
}
|
||||
|
||||
if oldConfig.CallbackURL != unified.Provider.RedirectURL {
|
||||
t.Error("Expected CallbackURL to match RedirectURL")
|
||||
}
|
||||
|
||||
if oldConfig.LogoutURL != unified.Provider.LogoutURL {
|
||||
t.Error("Expected LogoutURL to match")
|
||||
}
|
||||
|
||||
if oldConfig.ForceHTTPS != unified.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to match")
|
||||
}
|
||||
|
||||
if oldConfig.EnablePKCE != unified.Security.EnablePKCE {
|
||||
t.Error("Expected EnablePKCE to match")
|
||||
}
|
||||
|
||||
if oldConfig.RateLimit != unified.RateLimit.RequestsPerSecond {
|
||||
t.Errorf("Expected RateLimit %d, got %d", unified.RateLimit.RequestsPerSecond, oldConfig.RateLimit)
|
||||
}
|
||||
|
||||
if len(oldConfig.Headers) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(oldConfig.Headers))
|
||||
}
|
||||
}
|
||||
|
||||
// convertHeaders Tests
|
||||
func TestLegacyAdapter_convertHeaders(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Custom-Header-1": "value1",
|
||||
"X-Custom-Header-2": "value2",
|
||||
"X-Custom-Header-3": "value3",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 3 {
|
||||
t.Errorf("Expected 3 headers, got %d", len(headers))
|
||||
}
|
||||
|
||||
// Check that headers were converted
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h.Name] = h.Value
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-1"] != "value1" {
|
||||
t.Error("Expected X-Custom-Header-1 to have value 'value1'")
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-2"] != "value2" {
|
||||
t.Error("Expected X-Custom-Header-2 to have value 'value2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAdapter_convertHeaders_Empty(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
// No custom headers
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 0 {
|
||||
t.Errorf("Expected 0 headers, got %d", len(headers))
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigInterface Tests
|
||||
func TestGetConfigInterface(t *testing.T) {
|
||||
cfg := GetConfigInterface()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("Expected GetConfigInterface to return non-nil")
|
||||
}
|
||||
|
||||
// Should return either UnifiedConfig or Config depending on feature flag
|
||||
_, isUnified := cfg.(*UnifiedConfig)
|
||||
_, isOld := cfg.(*Config)
|
||||
|
||||
if !isUnified && !isOld {
|
||||
t.Error("Expected either *UnifiedConfig or *Config")
|
||||
}
|
||||
|
||||
// Verify consistency with feature flag
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if !isUnified {
|
||||
t.Error("Expected *UnifiedConfig when unified config is enabled")
|
||||
}
|
||||
} else {
|
||||
if !isOld {
|
||||
t.Error("Expected *Config when unified config is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig Tests
|
||||
func TestValidateConfig_UnifiedConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "client-id"
|
||||
unified.Provider.ClientSecret = "client-secret"
|
||||
unified.Session.EncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(unified)
|
||||
// Should succeed regardless of feature flag since we're passing the right type
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid unified config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_OldConfig(t *testing.T) {
|
||||
old := CreateConfig()
|
||||
old.ProviderURL = "https://provider.example.com"
|
||||
old.ClientID = "client-id"
|
||||
old.ClientSecret = "client-secret"
|
||||
old.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(old)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid old config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidType(t *testing.T) {
|
||||
// Pass something that's not a config
|
||||
err := ValidateConfig("not a config")
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil for unknown type, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Config.Validate Tests
|
||||
func TestConfig_Validate_Valid(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid config to pass, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingProviderURL(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ProviderURL")
|
||||
}
|
||||
|
||||
// Check if it's a ValidationErrors type
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ProviderURL" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ProviderURL validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientID(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientID")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientID" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientID validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_NoPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = false
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientSecret without PKCE")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientSecret" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientSecret validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_WithPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = true // PKCE enabled, so ClientSecret not required
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with PKCE enabled and no ClientSecret, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ShortEncryptionKey(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "short" // Too short
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for short encryption key")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "SessionEncryptionKey" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected SessionEncryptionKey validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MultipleErrors(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
// Missing ProviderURL, ClientID, and ClientSecret
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected validation errors")
|
||||
}
|
||||
|
||||
verrs, ok := err.(ValidationErrors)
|
||||
if !ok {
|
||||
t.Fatal("Expected ValidationErrors type")
|
||||
}
|
||||
|
||||
if len(verrs) < 2 {
|
||||
t.Errorf("Expected at least 2 validation errors, got %d", len(verrs))
|
||||
}
|
||||
}
|
||||
+854
-983
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,276 @@
|
||||
// Package config provides default values and initialization for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewUnifiedConfig creates a new unified configuration with sensible defaults
|
||||
func NewUnifiedConfig() *UnifiedConfig {
|
||||
return &UnifiedConfig{
|
||||
Provider: DefaultProviderConfig(),
|
||||
Session: DefaultSessionConfig(),
|
||||
Token: DefaultTokenConfig(),
|
||||
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
|
||||
Security: DefaultSecurityConfig(),
|
||||
Middleware: DefaultMiddlewareConfig(),
|
||||
Cache: DefaultCacheConfig(),
|
||||
RateLimit: DefaultRateLimitConfig(),
|
||||
Logging: DefaultLoggingConfig(),
|
||||
Metrics: DefaultMetricsConfig(),
|
||||
Health: DefaultHealthConfig(),
|
||||
Transport: DefaultTransportConfig(),
|
||||
Pool: DefaultPoolConfig(),
|
||||
Circuit: DefaultCircuitConfig(),
|
||||
Legacy: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultProviderConfig returns default provider configuration
|
||||
func DefaultProviderConfig() ProviderConfig {
|
||||
return ProviderConfig{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
OverrideScopes: false,
|
||||
CustomClaims: make(map[string]string),
|
||||
JWKCachePeriod: 24 * time.Hour,
|
||||
MetadataCacheTTL: 24 * time.Hour,
|
||||
Discovery: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSessionConfig returns default session configuration
|
||||
func DefaultSessionConfig() SessionConfig {
|
||||
return SessionConfig{
|
||||
Name: "oidc_session",
|
||||
MaxAge: 86400, // 24 hours
|
||||
ChunkSize: 4000, // Safe size for cookies
|
||||
MaxChunks: 5,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
StorageType: "cookie",
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTokenConfig returns default token configuration
|
||||
func DefaultTokenConfig() TokenConfig {
|
||||
return TokenConfig{
|
||||
AccessTokenTTL: 1 * time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
RefreshGracePeriod: 60 * time.Second,
|
||||
ValidationMode: "jwt",
|
||||
CacheEnabled: true,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
CacheNegativeTTL: 30 * time.Second,
|
||||
ValidateSignature: true,
|
||||
ValidateExpiry: true,
|
||||
ValidateAudience: true,
|
||||
ValidateIssuer: true,
|
||||
RequiredClaims: []string{"sub", "iat", "exp"},
|
||||
ClockSkew: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSecurityConfig returns default security configuration
|
||||
func DefaultSecurityConfig() SecurityConfig {
|
||||
return SecurityConfig{
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
AllowedUsers: []string{},
|
||||
AllowedUserDomains: []string{},
|
||||
AllowedRolesAndGroups: []string{},
|
||||
ExcludedURLs: []string{
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
"/health",
|
||||
"/.well-known/",
|
||||
"/metrics",
|
||||
"/ping",
|
||||
"/static/",
|
||||
"/assets/",
|
||||
"/js/",
|
||||
"/css/",
|
||||
"/images/",
|
||||
"/fonts/",
|
||||
},
|
||||
Headers: createDefaultSecurityConfig(),
|
||||
CSRFProtection: true,
|
||||
CSRFTokenName: "csrf_token",
|
||||
CSRFTokenTTL: 1 * time.Hour,
|
||||
MaxLoginAttempts: 5,
|
||||
LockoutDuration: 15 * time.Minute,
|
||||
RequireMFA: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMiddlewareConfig returns default middleware configuration
|
||||
func DefaultMiddlewareConfig() MiddlewareConfig {
|
||||
return MiddlewareConfig{
|
||||
Priority: 1000,
|
||||
SkipPaths: []string{},
|
||||
RequirePaths: []string{},
|
||||
PassthroughMode: false,
|
||||
MaxRequestSize: 10 * 1024 * 1024, // 10MB
|
||||
RequestTimeout: 30 * time.Second,
|
||||
IdleTimeout: 90 * time.Second,
|
||||
CustomHeaders: make(map[string]string),
|
||||
RemoveHeaders: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCacheConfig returns default cache configuration
|
||||
func DefaultCacheConfig() CacheConfig {
|
||||
return CacheConfig{
|
||||
Enabled: true,
|
||||
Type: "memory",
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxEntries: 10000,
|
||||
MaxEntrySize: 1024 * 1024, // 1MB
|
||||
EvictionPolicy: "lru",
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
Namespace: "traefikoidc",
|
||||
Compression: false,
|
||||
Serialization: "json",
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRateLimitConfig returns default rate limiting configuration
|
||||
func DefaultRateLimitConfig() RateLimitConfig {
|
||||
return RateLimitConfig{
|
||||
Enabled: false,
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
StorageType: "memory",
|
||||
WindowDuration: 1 * time.Minute,
|
||||
KeyType: "ip",
|
||||
CustomKeyFunc: "",
|
||||
WhitelistIPs: []string{},
|
||||
WhitelistUsers: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLoggingConfig returns default logging configuration
|
||||
func DefaultLoggingConfig() LoggingConfig {
|
||||
return LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
FilePath: "",
|
||||
FilterSensitive: true,
|
||||
MaskFields: []string{
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
},
|
||||
BufferSize: 8192,
|
||||
FlushInterval: 5 * time.Second,
|
||||
AuditEnabled: false,
|
||||
AuditEvents: []string{
|
||||
"login",
|
||||
"logout",
|
||||
"token_refresh",
|
||||
"auth_failure",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMetricsConfig returns default metrics configuration
|
||||
func DefaultMetricsConfig() MetricsConfig {
|
||||
return MetricsConfig{
|
||||
Enabled: false,
|
||||
Provider: "prometheus",
|
||||
Endpoint: "/metrics",
|
||||
Namespace: "traefikoidc",
|
||||
Subsystem: "middleware",
|
||||
CollectInterval: 10 * time.Second,
|
||||
Histograms: true,
|
||||
Labels: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHealthConfig returns default health check configuration
|
||||
func DefaultHealthConfig() HealthConfig {
|
||||
return HealthConfig{
|
||||
Enabled: true,
|
||||
Path: "/health",
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
CheckProvider: true,
|
||||
CheckRedis: true,
|
||||
CheckCache: true,
|
||||
MaxLatency: 1 * time.Second,
|
||||
MinMemory: 100 * 1024 * 1024, // 100MB
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTransportConfig returns default HTTP transport configuration
|
||||
func DefaultTransportConfig() TransportConfig {
|
||||
return TransportConfig{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 0, // No limit
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
TLSMinVersion: "TLS1.2",
|
||||
TLSCipherSuites: []string{},
|
||||
ProxyURL: "",
|
||||
NoProxy: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultPoolConfig returns default connection pool configuration
|
||||
func DefaultPoolConfig() PoolConfig {
|
||||
return PoolConfig{
|
||||
Enabled: true,
|
||||
Size: 10,
|
||||
MinSize: 2,
|
||||
MaxSize: 50,
|
||||
MaxAge: 30 * time.Minute,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
WaitTimeout: 5 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCircuitConfig returns default circuit breaker configuration
|
||||
func DefaultCircuitConfig() CircuitConfig {
|
||||
return CircuitConfig{
|
||||
Enabled: true,
|
||||
MaxRequests: 100,
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 60 * time.Second,
|
||||
ConsecutiveFailures: 5,
|
||||
FailureRatio: 0.5,
|
||||
OnOpen: "reject",
|
||||
OnHalfOpen: "passthrough",
|
||||
MetricsEnabled: true,
|
||||
LogStateChanges: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeWithDefaults merges a partial configuration with defaults
|
||||
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
|
||||
if partial == nil {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
|
||||
// Ensure Legacy field is initialized
|
||||
if partial.Legacy == nil {
|
||||
partial.Legacy = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// TODO: Implement deep merge logic with defaults
|
||||
// For now, just return the partial config
|
||||
return partial
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
// Package config provides configuration loading and merging logic
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigLoader handles loading configuration from various sources
|
||||
type ConfigLoader struct {
|
||||
migrator *ConfigMigrator
|
||||
envPrefix string
|
||||
configPaths []string
|
||||
}
|
||||
|
||||
// NewConfigLoader creates a new configuration loader
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
migrator: NewConfigMigrator(),
|
||||
envPrefix: "TRAEFIKOIDC_",
|
||||
configPaths: getDefaultConfigPaths(),
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultConfigPaths returns default configuration file paths to check
|
||||
func getDefaultConfigPaths() []string {
|
||||
return []string{
|
||||
"traefik-oidc.yaml",
|
||||
"traefik-oidc.yml",
|
||||
"traefik-oidc.json",
|
||||
"config.yaml",
|
||||
"config.yml",
|
||||
"config.json",
|
||||
"/etc/traefik-oidc/config.yaml",
|
||||
"/etc/traefik-oidc/config.json",
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from all available sources
|
||||
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
|
||||
// Start with defaults
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Try to load from file
|
||||
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
|
||||
config = l.mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
l.LoadFromEnv(config)
|
||||
|
||||
// Validate the final configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadFromFile loads configuration from a file
|
||||
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
|
||||
// Use provided paths or default paths
|
||||
searchPaths := paths
|
||||
if len(searchPaths) == 0 {
|
||||
searchPaths = l.configPaths
|
||||
}
|
||||
|
||||
// Check for config file in environment variable
|
||||
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
|
||||
searchPaths = append([]string{envPath}, searchPaths...)
|
||||
}
|
||||
|
||||
// Try each path
|
||||
for _, path := range searchPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return l.loadFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
// No config file found, not an error (use defaults)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// loadFile loads a specific configuration file
|
||||
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories (current dir or subdirs)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
// Check if unified config is enabled
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
// Use migrator to handle any version
|
||||
config, warnings, err := l.migrator.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, use proper logging
|
||||
fmt.Printf("Config Warning (%s): %s\n", path, warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Legacy path: load old config and convert
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
var oldConfig Config
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
if err := json.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
|
||||
}
|
||||
case ".yaml", ".yml":
|
||||
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
|
||||
}
|
||||
|
||||
return FromOldConfig(&oldConfig), nil
|
||||
}
|
||||
|
||||
// LoadFromEnv loads configuration from environment variables
|
||||
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
|
||||
// Provider configuration
|
||||
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
|
||||
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
|
||||
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
|
||||
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
|
||||
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
|
||||
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
|
||||
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
|
||||
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
|
||||
|
||||
// Session configuration
|
||||
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
|
||||
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
|
||||
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
|
||||
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
|
||||
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
|
||||
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
|
||||
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
|
||||
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
|
||||
|
||||
// Security configuration
|
||||
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
|
||||
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
|
||||
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
|
||||
|
||||
// Cache configuration
|
||||
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
|
||||
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
|
||||
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
|
||||
// MaxEntrySize is int64, skip for now
|
||||
|
||||
// Rate limiting
|
||||
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
|
||||
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
|
||||
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
|
||||
|
||||
// Logging
|
||||
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
|
||||
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
|
||||
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
|
||||
|
||||
// Redis configuration (already handled by its own LoadFromEnv)
|
||||
config.Redis.LoadFromEnv()
|
||||
|
||||
// Feature flags
|
||||
features.GetManager().LoadFromEnv()
|
||||
}
|
||||
|
||||
// Helper methods for environment variable loading
|
||||
|
||||
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitAndTrim(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeConfigs merges two configurations, with source overriding target
|
||||
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
|
||||
if source == nil {
|
||||
return target
|
||||
}
|
||||
if target == nil {
|
||||
return source
|
||||
}
|
||||
|
||||
// Use reflection for deep merge
|
||||
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
// mergeStructs recursively merges two structs
|
||||
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
|
||||
for i := 0; i < source.NumField(); i++ {
|
||||
sourceField := source.Field(i)
|
||||
targetField := target.Field(i)
|
||||
|
||||
// Skip if source field is zero value
|
||||
if isZeroValue(sourceField) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sourceField.Kind() {
|
||||
case reflect.Struct:
|
||||
// Recursively merge structs
|
||||
l.mergeStructs(targetField, sourceField)
|
||||
case reflect.Slice:
|
||||
// Replace slice if source has values
|
||||
if sourceField.Len() > 0 {
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
case reflect.Map:
|
||||
// Merge maps
|
||||
if !sourceField.IsNil() {
|
||||
if targetField.IsNil() {
|
||||
targetField.Set(reflect.MakeMap(sourceField.Type()))
|
||||
}
|
||||
for _, key := range sourceField.MapKeys() {
|
||||
targetField.SetMapIndex(key, sourceField.MapIndex(key))
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Replace value
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValue checks if a reflect.Value is a zero value
|
||||
func isZeroValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return v.IsNil()
|
||||
case reflect.Slice, reflect.Map:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Struct:
|
||||
// Check if all fields are zero
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !isZeroValue(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
zero := reflect.Zero(v.Type())
|
||||
return reflect.DeepEqual(v.Interface(), zero.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// SaveToFile saves the configuration to a file
|
||||
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(absPath))
|
||||
|
||||
var data []byte
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
data, err = json.MarshalIndent(config, "", " ")
|
||||
case ".yaml", ".yml":
|
||||
data, err = yaml.Marshal(config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported file extension: %s", ext)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist with secure permissions
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
// Write file with secure permissions (owner read/write only)
|
||||
if err := os.WriteFile(absPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,832 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConfigLoader tests the config loader functionality
|
||||
func TestConfigLoader(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
if loader == nil {
|
||||
t.Fatal("NewConfigLoader should not return nil")
|
||||
}
|
||||
|
||||
if loader.migrator == nil {
|
||||
t.Error("ConfigLoader should have a migrator")
|
||||
}
|
||||
|
||||
if loader.envPrefix != "TRAEFIKOIDC_" {
|
||||
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
|
||||
}
|
||||
|
||||
if len(loader.configPaths) == 0 {
|
||||
t.Error("ConfigLoader should have default config paths")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFromEnv tests loading configuration from environment variables
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
testEnvVars := map[string]string{
|
||||
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
|
||||
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
|
||||
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ENABLED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
|
||||
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
|
||||
"TRAEFIKOIDC_CACHE_ENABLED": "true",
|
||||
"TRAEFIKOIDC_CACHE_TYPE": "redis",
|
||||
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
|
||||
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
|
||||
}
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range testEnvVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{}
|
||||
loader.LoadFromEnv(config)
|
||||
|
||||
// Verify values were loaded
|
||||
if config.Provider.IssuerURL != "https://test.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client-id" {
|
||||
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
|
||||
}
|
||||
if config.Provider.ClientSecret != "test-secret" {
|
||||
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
|
||||
}
|
||||
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
|
||||
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
|
||||
}
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true")
|
||||
}
|
||||
if !config.Cache.Enabled {
|
||||
t.Error("Expected Cache to be enabled")
|
||||
}
|
||||
if config.Cache.Type != "redis" {
|
||||
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
|
||||
}
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("Expected RateLimit to be enabled")
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond != 100 {
|
||||
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveToFile tests saving configuration to files
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "32-character-encryption-key-12345",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "save as JSON",
|
||||
filename: "config.json",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YAML",
|
||||
filename: "config.yaml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YML",
|
||||
filename: "config.yml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported extension",
|
||||
filename: "config.txt",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/config.json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, tt.filename)
|
||||
err := loader.SaveToFile(config, filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file was created with correct permissions
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to stat saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check file permissions (should be 0600)
|
||||
mode := info.Mode().Perm()
|
||||
if mode != 0600 {
|
||||
t.Errorf("Expected file permissions 0600, got %o", mode)
|
||||
}
|
||||
|
||||
// Verify content can be read back
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify secrets are redacted
|
||||
content := string(data)
|
||||
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
|
||||
t.Error("Secrets should be redacted in saved file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFile tests loading configuration from files
|
||||
func TestLoadFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Test data - using old config format since unified config is not enabled by default
|
||||
jsonConfig := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
|
||||
yamlConfig := `
|
||||
providerurl: https://auth.example.com
|
||||
clientid: test-client
|
||||
clientsecret: secret
|
||||
sessionencryptionkey: 32-character-encryption-key-12345
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "load JSON config",
|
||||
filename: "config.json",
|
||||
content: jsonConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "load YAML config",
|
||||
filename: "config.yaml",
|
||||
content: yamlConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/passwd",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
filename: "does-not-exist.json",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var filePath string
|
||||
if tt.content != "" {
|
||||
filePath = filepath.Join(tmpDir, tt.filename)
|
||||
err := os.WriteFile(filePath, []byte(tt.content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
filePath = tt.filename
|
||||
}
|
||||
|
||||
config, err := loader.loadFile(filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verify loaded config
|
||||
if config == nil {
|
||||
t.Error("Expected config to be loaded")
|
||||
return
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================================
|
||||
// Tests for untested functions (0% coverage)
|
||||
// ====================================================================================
|
||||
|
||||
// TestConfigLoader_Load tests the full Load pipeline
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-load-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test config file
|
||||
configPath := filepath.Join(tmpDir, "traefik-oidc.json")
|
||||
configData := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to temp directory so loader can find the config
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
// Set some environment variables to test merging
|
||||
os.Setenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS", "true")
|
||||
defer os.Unsetenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.Load()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("Load() returned nil config")
|
||||
}
|
||||
|
||||
// Verify file was loaded
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL from file, got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Verify env vars were loaded
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS from env var to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_LoadFromFile tests the LoadFromFile function
|
||||
func TestConfigLoader_LoadFromFile(t *testing.T) {
|
||||
t.Run("NoConfigFile", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-nofile-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
// Should not error when no config file found
|
||||
if err != nil {
|
||||
t.Errorf("LoadFromFile() should not error when no file found: %v", err)
|
||||
}
|
||||
|
||||
// Should return nil config
|
||||
if config != nil {
|
||||
t.Error("LoadFromFile() should return nil config when no file found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadFromEnvPath", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-envpath-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "custom-config.json")
|
||||
configData := `{
|
||||
"providerURL": "https://custom.example.com",
|
||||
"clientID": "custom-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Set env variable pointing to config
|
||||
os.Setenv("TRAEFIKOIDC_CONFIG_FILE", configPath)
|
||||
defer os.Unsetenv("TRAEFIKOIDC_CONFIG_FILE")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://custom.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://custom.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadWithProvidedPaths", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-provided-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "specific.json")
|
||||
configData := `{
|
||||
"providerURL": "https://specific.example.com",
|
||||
"clientID": "specific-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile(configPath)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() with path failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://specific.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://specific.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAndTrim tests the splitAndTrim helper function
|
||||
func TestSplitAndTrim(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Simple comma-separated",
|
||||
input: "a,b,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "With spaces",
|
||||
input: "a, b , c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Empty strings filtered out",
|
||||
input: "a,,b, ,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Leading and trailing spaces",
|
||||
input: " a , b , c ",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "single",
|
||||
expected: []string{"single"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Only commas and spaces",
|
||||
input: " , , , ",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Complex real-world example",
|
||||
input: "openid, profile, email, groups",
|
||||
expected: []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitAndTrim(tt.input)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d items, got %d: %v", len(tt.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expected {
|
||||
if result[i] != expected {
|
||||
t.Errorf("At index %d: expected %q, got %q", i, expected, result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeConfigs tests the mergeConfigs function
|
||||
func TestConfigLoader_MergeConfigs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("MergeNilSource", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, nil)
|
||||
|
||||
if result != target {
|
||||
t.Error("mergeConfigs should return target when source is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeNilTarget", func(t *testing.T) {
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(nil, source)
|
||||
|
||||
if result != source {
|
||||
t.Error("mergeConfigs should return source when target is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSimpleFields", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "",
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
ClientID: "source-client",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if result.Provider.IssuerURL != "https://source.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be overridden, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be set, got %s", result.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSlices", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"openid", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Source slice should replace target slice
|
||||
if len(result.Provider.Scopes) != 2 {
|
||||
t.Errorf("Expected 2 scopes, got %d", len(result.Provider.Scopes))
|
||||
}
|
||||
|
||||
if result.Provider.Scopes[0] != "email" {
|
||||
t.Errorf("Expected first scope 'email', got %s", result.Provider.Scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeMaps", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Target-Header": "target-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Source-Header": "source-value",
|
||||
"X-Target-Header": "overridden-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if len(result.Middleware.CustomHeaders) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(result.Middleware.CustomHeaders))
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Target-Header"] != "overridden-value" {
|
||||
t.Errorf("Expected X-Target-Header to be overridden")
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Source-Header"] != "source-value" {
|
||||
t.Errorf("Expected X-Source-Header to be added")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeStructs tests the mergeStructs function indirectly
|
||||
func TestConfigLoader_MergeStructs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("NestedStructMerge", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "target-client",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "target-session",
|
||||
MaxAge: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "source-client",
|
||||
ClientSecret: "source-secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
MaxAge: 7200,
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Provider.IssuerURL should remain (zero value in source)
|
||||
if result.Provider.IssuerURL != "https://target.example.com" {
|
||||
t.Errorf("Expected IssuerURL to remain, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Provider.ClientID should be overridden
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be overridden, got %s", result.Provider.ClientID)
|
||||
}
|
||||
|
||||
// Provider.ClientSecret should be added
|
||||
if result.Provider.ClientSecret != "source-secret" {
|
||||
t.Errorf("Expected ClientSecret to be added, got %s", result.Provider.ClientSecret)
|
||||
}
|
||||
|
||||
// Session.Name should remain (zero value in source)
|
||||
if result.Session.Name != "target-session" {
|
||||
t.Errorf("Expected Session.Name to remain, got %s", result.Session.Name)
|
||||
}
|
||||
|
||||
// Session.MaxAge should be overridden
|
||||
if result.Session.MaxAge != 7200 {
|
||||
t.Errorf("Expected Session.MaxAge to be overridden, got %d", result.Session.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsZeroValue tests the isZeroValue helper function
|
||||
func TestIsZeroValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Zero string",
|
||||
value: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero string",
|
||||
value: "hello",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero int",
|
||||
value: 0,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero int",
|
||||
value: 42,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero bool",
|
||||
value: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero bool",
|
||||
value: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil pointer",
|
||||
value: (*string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-nil pointer",
|
||||
value: stringPtr("test"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil slice",
|
||||
value: ([]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
value: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty slice",
|
||||
value: []string{"a"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil map",
|
||||
value: (map[string]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty map",
|
||||
value: map[string]string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty map",
|
||||
value: map[string]string{"key": "value"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.value)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected isZeroValue to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsZeroValue_Struct tests isZeroValue with struct types
|
||||
func TestIsZeroValue_Struct(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Field1 string
|
||||
Field2 int
|
||||
}
|
||||
|
||||
t.Run("Zero struct", func(t *testing.T) {
|
||||
s := TestStruct{}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected zero struct to return true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field1 set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test"}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field2 set", func(t *testing.T) {
|
||||
s := TestStruct{Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Both fields set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test", Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function for pointer tests
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
// Package config provides configuration migration from old to new format
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigVersion represents the version of a configuration format
|
||||
type ConfigVersion string
|
||||
|
||||
const (
|
||||
// VersionLegacy represents the original config format
|
||||
VersionLegacy ConfigVersion = "legacy"
|
||||
|
||||
// VersionUnified represents the new unified config format
|
||||
VersionUnified ConfigVersion = "unified"
|
||||
|
||||
// CurrentVersion is the current config version
|
||||
CurrentVersion ConfigVersion = VersionUnified
|
||||
)
|
||||
|
||||
// ConfigMigrator handles migration between config versions
|
||||
type ConfigMigrator struct {
|
||||
compatLayer *compat.CompatibilityLayer
|
||||
migrations map[ConfigVersion]MigrationFunc
|
||||
}
|
||||
|
||||
// MigrationFunc defines a function that migrates configuration
|
||||
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
|
||||
|
||||
// NewConfigMigrator creates a new configuration migrator
|
||||
func NewConfigMigrator() *ConfigMigrator {
|
||||
m := &ConfigMigrator{
|
||||
compatLayer: compat.GetLayer(),
|
||||
migrations: make(map[ConfigVersion]MigrationFunc),
|
||||
}
|
||||
|
||||
// Register migration functions
|
||||
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// DetectVersion detects the version of a configuration
|
||||
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
|
||||
var testMap map[string]interface{}
|
||||
|
||||
// Try JSON first
|
||||
if err := json.Unmarshal(data, &testMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &testMap); err != nil {
|
||||
return VersionLegacy // Default to legacy if can't parse
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unified config markers
|
||||
if _, hasProvider := testMap["provider"]; hasProvider {
|
||||
if _, hasSession := testMap["session"]; hasSession {
|
||||
return VersionUnified
|
||||
}
|
||||
}
|
||||
|
||||
// Check for legacy config markers
|
||||
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
// Migrate migrates configuration data to the current version
|
||||
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
|
||||
warnings := []string{}
|
||||
|
||||
// Detect version
|
||||
version := m.DetectVersion(data)
|
||||
|
||||
// If already current version, just unmarshal
|
||||
if version == CurrentVersion {
|
||||
var config UnifiedConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
|
||||
}
|
||||
}
|
||||
return &config, warnings, nil
|
||||
}
|
||||
|
||||
// Parse to generic map
|
||||
var configMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &configMap); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply migration
|
||||
migrationFunc, exists := m.migrations[version]
|
||||
if !exists {
|
||||
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
|
||||
}
|
||||
|
||||
config, err := migrationFunc(configMap)
|
||||
if err != nil {
|
||||
return nil, warnings, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Collect any deprecation warnings
|
||||
for key := range configMap {
|
||||
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
|
||||
warnings = append(warnings, warning)
|
||||
}
|
||||
}
|
||||
|
||||
return config, warnings, nil
|
||||
}
|
||||
|
||||
// migrateLegacyToUnified migrates legacy config to unified format
|
||||
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Use compatibility layer for field mapping
|
||||
migratedMap, warnings := m.compatLayer.MigrateMap(data)
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, these would be logged
|
||||
_ = warning
|
||||
}
|
||||
|
||||
// Map provider configuration
|
||||
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
|
||||
_ = mapToStruct(provider, &config.Provider)
|
||||
} else {
|
||||
// Direct field mapping for legacy format
|
||||
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
|
||||
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
|
||||
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
|
||||
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
|
||||
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
|
||||
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
|
||||
|
||||
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
|
||||
config.Provider.Scopes = scopes
|
||||
}
|
||||
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
|
||||
}
|
||||
|
||||
// Map session configuration
|
||||
if session, ok := getNestedMap(migratedMap, "Session"); ok {
|
||||
_ = mapToStruct(session, &config.Session)
|
||||
} else {
|
||||
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
|
||||
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
|
||||
}
|
||||
|
||||
// Map security configuration
|
||||
if security, ok := getNestedMap(migratedMap, "Security"); ok {
|
||||
_ = mapToStruct(security, &config.Security)
|
||||
} else {
|
||||
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
|
||||
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
|
||||
|
||||
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
|
||||
config.Security.AllowedUsers = users
|
||||
}
|
||||
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
|
||||
config.Security.AllowedUserDomains = domains
|
||||
}
|
||||
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
|
||||
config.Security.AllowedRolesAndGroups = roles
|
||||
}
|
||||
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
|
||||
config.Security.ExcludedURLs = excluded
|
||||
}
|
||||
|
||||
// Handle security headers
|
||||
if headers := data["securityHeaders"]; headers != nil {
|
||||
// Security headers might be in old format
|
||||
_ = mapToStruct(headers, &config.Security.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
// Map rate limiting
|
||||
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = rateLimit
|
||||
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
|
||||
}
|
||||
|
||||
// Map token configuration
|
||||
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
|
||||
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
|
||||
}
|
||||
|
||||
// Map logging
|
||||
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
|
||||
if config.Logging.Level == "" {
|
||||
config.Logging.Level = "info"
|
||||
}
|
||||
|
||||
// Map custom headers
|
||||
if headers := data["headers"]; headers != nil {
|
||||
if headerList, ok := headers.([]interface{}); ok {
|
||||
config.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, h := range headerList {
|
||||
if headerMap, ok := h.(map[string]interface{}); ok {
|
||||
name := getStringFromInterface(headerMap["name"])
|
||||
value := getStringFromInterface(headerMap["value"])
|
||||
if name != "" {
|
||||
config.Middleware.CustomHeaders[name] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store original data for reference
|
||||
config.Legacy = data
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// MigrateFile migrates a configuration file
|
||||
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
config, warnings, err := m.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
fmt.Printf("Migration Warning: %s\n", warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// AutoMigrate automatically migrates config based on feature flags
|
||||
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Feature not enabled, return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
migrator := NewConfigMigrator()
|
||||
|
||||
// Handle different input types
|
||||
switch v := data.(type) {
|
||||
case []byte:
|
||||
config, _, err := migrator.Migrate(v)
|
||||
return config, err
|
||||
case string:
|
||||
config, _, err := migrator.Migrate([]byte(v))
|
||||
return config, err
|
||||
case *Config:
|
||||
// Convert old config to unified
|
||||
return FromOldConfig(v), nil
|
||||
case *UnifiedConfig:
|
||||
// Already unified
|
||||
return v, nil
|
||||
case map[string]interface{}:
|
||||
// Convert map to JSON then migrate
|
||||
jsonData, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config, _, err := migrator.Migrate(jsonData)
|
||||
return config, err
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||
if val, exists := m[key]; exists {
|
||||
if mapped, ok := val.(map[string]interface{}); ok {
|
||||
return mapped, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getStringValue(m map[string]interface{}, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
return getStringFromInterface(val)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getStringFromInterface(val interface{}) string {
|
||||
if val == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func getBoolValue(m map[string]interface{}, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
// Try string conversion
|
||||
if s, ok := val.(string); ok {
|
||||
return strings.ToLower(s) == "true"
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getIntValue(m map[string]interface{}, keys ...string) int {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
// Try to parse
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
|
||||
// If parsing fails, return default
|
||||
return 0
|
||||
}
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getArrayValue(m map[string]interface{}, keys ...string) []string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if arr, ok := val.([]interface{}); ok {
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
result = append(result, getStringFromInterface(item))
|
||||
}
|
||||
return result
|
||||
}
|
||||
if strArr, ok := val.([]string); ok {
|
||||
return strArr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapToStruct(m interface{}, target interface{}) error {
|
||||
// Simple mapping using JSON as intermediate
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, target)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,297 @@
|
||||
// Package config provides configuration structures for the Traefik OIDC plugin.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedisMode represents the Redis deployment mode
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
// RedisModeStandalone represents a single Redis instance
|
||||
RedisModeStandalone RedisMode = "standalone"
|
||||
|
||||
// RedisModeCluster represents Redis cluster mode
|
||||
RedisModeCluster RedisMode = "cluster"
|
||||
|
||||
// RedisModeSentinel represents Redis sentinel mode
|
||||
RedisModeSentinel RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
// RedisConfig holds Redis cache backend configuration
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis backend should be used
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// Mode specifies the Redis deployment mode
|
||||
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
|
||||
|
||||
// === Standalone Configuration ===
|
||||
// Addr is the Redis server address (host:port)
|
||||
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
|
||||
|
||||
// Password for Redis authentication
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the database number (0-15)
|
||||
DB int `json:"db,omitempty" yaml:"db,omitempty"`
|
||||
|
||||
// === Cluster Configuration ===
|
||||
// ClusterAddrs is the list of cluster node addresses
|
||||
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
|
||||
|
||||
// === Sentinel Configuration ===
|
||||
// MasterName is the name of the master instance
|
||||
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
|
||||
|
||||
// SentinelAddrs is the list of sentinel addresses
|
||||
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
|
||||
|
||||
// SentinelPassword is the password for sentinel authentication
|
||||
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
|
||||
|
||||
// === Connection Pool Settings ===
|
||||
// PoolSize is the maximum number of socket connections
|
||||
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections
|
||||
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up
|
||||
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
|
||||
|
||||
// === Timeouts ===
|
||||
// DialTimeout is the timeout for establishing new connections
|
||||
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
|
||||
|
||||
// ReadTimeout is the timeout for socket reads
|
||||
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
|
||||
|
||||
// WriteTimeout is the timeout for socket writes
|
||||
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
|
||||
|
||||
// PoolTimeout is the timeout for connection pool
|
||||
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
|
||||
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
|
||||
|
||||
// ConnMaxLifetime is the maximum lifetime of a connection
|
||||
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
|
||||
|
||||
// === Key Management ===
|
||||
// KeyPrefix is the prefix for all Redis keys
|
||||
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
|
||||
|
||||
// === TLS Configuration ===
|
||||
// TLSEnabled enables TLS for Redis connections
|
||||
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
|
||||
|
||||
// TLSInsecureSkipVerify skips TLS certificate verification
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
|
||||
|
||||
// === Resilience Settings ===
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis operations
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
|
||||
|
||||
// CircuitBreakerMaxFailures is the number of failures before opening circuit
|
||||
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
|
||||
|
||||
// CircuitBreakerTimeout is how long the circuit stays open
|
||||
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks
|
||||
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
|
||||
|
||||
// HealthCheckInterval is how often to check Redis health
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns default Redis configuration
|
||||
func DefaultRedisConfig() *RedisConfig {
|
||||
return &RedisConfig{
|
||||
Enabled: false,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 2,
|
||||
MaxRetries: 3,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
KeyPrefix: "traefikoidc:",
|
||||
TLSEnabled: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
EnableCircuitBreaker: true,
|
||||
CircuitBreakerMaxFailures: 5,
|
||||
CircuitBreakerTimeout: 30 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Redis configuration from environment variables
|
||||
func (c *RedisConfig) LoadFromEnv() {
|
||||
// Enable Redis if environment variable is set
|
||||
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
|
||||
c.Enabled = strings.ToLower(enabled) == "true"
|
||||
}
|
||||
|
||||
// Mode
|
||||
if mode := os.Getenv("REDIS_MODE"); mode != "" {
|
||||
c.Mode = RedisMode(strings.ToLower(mode))
|
||||
}
|
||||
|
||||
// Standalone configuration
|
||||
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
|
||||
c.Addr = addr
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
c.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster configuration
|
||||
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
|
||||
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
|
||||
for i := range c.ClusterAddrs {
|
||||
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel configuration
|
||||
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
|
||||
c.MasterName = masterName
|
||||
}
|
||||
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
|
||||
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
|
||||
for i := range c.SentinelAddrs {
|
||||
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
|
||||
}
|
||||
}
|
||||
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
|
||||
c.SentinelPassword = sentinelPassword
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if size, err := strconv.Atoi(poolSize); err == nil {
|
||||
c.PoolSize = size
|
||||
}
|
||||
}
|
||||
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
|
||||
if conns, err := strconv.Atoi(minIdleConns); err == nil {
|
||||
c.MinIdleConns = conns
|
||||
}
|
||||
}
|
||||
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
|
||||
if retries, err := strconv.Atoi(maxRetries); err == nil {
|
||||
c.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
// Timeouts
|
||||
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
|
||||
c.DialTimeout = timeout
|
||||
}
|
||||
}
|
||||
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(readTimeout); err == nil {
|
||||
c.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
|
||||
c.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Key prefix
|
||||
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
|
||||
c.KeyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
|
||||
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
|
||||
}
|
||||
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
|
||||
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
|
||||
}
|
||||
|
||||
// Resilience settings
|
||||
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
|
||||
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
|
||||
}
|
||||
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
|
||||
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
|
||||
c.CircuitBreakerMaxFailures = failures
|
||||
}
|
||||
}
|
||||
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
|
||||
c.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
|
||||
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
|
||||
}
|
||||
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
|
||||
if interval, err := time.ParseDuration(hcInterval); err == nil {
|
||||
c.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *RedisConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case RedisModeStandalone:
|
||||
if c.Addr == "" {
|
||||
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
|
||||
}
|
||||
case RedisModeCluster:
|
||||
if len(c.ClusterAddrs) == 0 {
|
||||
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
|
||||
}
|
||||
case RedisModeSentinel:
|
||||
if c.MasterName == "" {
|
||||
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
|
||||
}
|
||||
if len(c.SentinelAddrs) == 0 {
|
||||
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
|
||||
}
|
||||
default:
|
||||
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ConfigError) Error() string {
|
||||
return "redis config error: " + e.Field + ": " + e.Message
|
||||
}
|
||||
+421
-121
@@ -2,12 +2,10 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -49,27 +47,111 @@ type Logger interface {
|
||||
|
||||
// Config represents the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerUrl"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
LogoutURL string `json:"logoutUrl"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHttps"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
Scopes []string `json:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedUrls"`
|
||||
EnablePKCE bool `json:"enablePkce"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
Headers []HeaderConfig `json:"headers"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
ProviderURL string `json:"providerUrl"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
LogoutURL string `json:"logoutUrl"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHttps"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
Scopes []string `json:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedUrls"`
|
||||
EnablePKCE bool `json:"enablePkce"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
Headers []HeaderConfig `json:"headers"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
|
||||
// Dynamic Client Registration (RFC 7591) configuration
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrationConfig struct {
|
||||
// Enabled enables automatic client registration with the OIDC provider
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// InitialAccessToken is an optional bearer token for protected registration endpoints
|
||||
// Some providers require this token to authorize new client registrations
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
|
||||
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
|
||||
// If empty, uses the registration_endpoint from .well-known/openid-configuration
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
|
||||
// ClientMetadata contains the client metadata to register
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
|
||||
// PersistCredentials determines whether to save registered credentials to a file
|
||||
// This allows reusing the same client_id/client_secret across restarts
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
|
||||
// CredentialsFile is the path to store/load registered client credentials
|
||||
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
type ClientRegistrationMetadata struct {
|
||||
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
|
||||
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
|
||||
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
|
||||
// ApplicationType is either "web" (default) or "native"
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
|
||||
// Contacts is an array of email addresses for responsible parties
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
|
||||
// ClientName is a human-readable name for the client
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
|
||||
// LogoURI is a URL pointing to a logo for the client
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
|
||||
// ClientURI is a URL of the home page of the client
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
|
||||
// PolicyURI is a URL pointing to the client's privacy policy
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
|
||||
// TOSURI is a URL pointing to the client's terms of service
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
|
||||
// JWKSURI is a URL for the client's JSON Web Key Set
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
|
||||
// SubjectType is "pairwise" or "public" (provider-specific)
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
|
||||
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
|
||||
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
|
||||
// DefaultMaxAge is the default maximum authentication age in seconds
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
|
||||
// RequireAuthTime specifies whether auth_time claim is required in ID token
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
|
||||
// DefaultACRValues specifies default ACR values
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
|
||||
// Scope is a space-separated list of scopes (alternative to config.Scopes)
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// HeaderConfig represents header template configuration
|
||||
@@ -78,6 +160,59 @@ type HeaderConfig struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
type SecurityHeadersConfig struct {
|
||||
// Enable security headers (default: true)
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Security profile: "default", "strict", "development", "api", or "custom"
|
||||
Profile string `json:"profile"`
|
||||
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
|
||||
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
|
||||
// Content type options (default: "nosniff")
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
|
||||
// XSS protection (default: "1; mode=block")
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
CORSMaxAge int `json:"corsMaxAge"` // seconds
|
||||
|
||||
// Custom headers (in addition to standard security headers)
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
}
|
||||
|
||||
// NewSettings creates a new Settings instance
|
||||
func NewSettings(logger Logger) *Settings {
|
||||
return &Settings{
|
||||
@@ -95,117 +230,282 @@ func CreateConfig() *Config {
|
||||
RefreshGracePeriodSeconds: 60,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Headers: []HeaderConfig{},
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeTraefikOidc would initialize and configure a new TraefikOidc instance
|
||||
// This functionality has been moved to the main New function in main.go
|
||||
// This function is kept for compatibility but should not be used
|
||||
func (s *Settings) InitializeTraefikOidc(ctx context.Context, next http.Handler, config *Config, name string) (interface{}, error) {
|
||||
return nil, fmt.Errorf("InitializeTraefikOidc is deprecated - use New function from main package instead")
|
||||
// createDefaultSecurityConfig creates a default security headers configuration
|
||||
func createDefaultSecurityConfig() *SecurityHeadersConfig {
|
||||
return &SecurityHeadersConfig{
|
||||
Enabled: true,
|
||||
Profile: "default",
|
||||
|
||||
// Default security headers
|
||||
StrictTransportSecurity: true,
|
||||
StrictTransportSecurityMaxAge: 31536000, // 1 year
|
||||
StrictTransportSecuritySubdomains: true,
|
||||
StrictTransportSecurityPreload: true,
|
||||
|
||||
FrameOptions: "DENY",
|
||||
ContentTypeOptions: "nosniff",
|
||||
XSSProtection: "1; mode=block",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
|
||||
// CORS disabled by default
|
||||
CORSEnabled: false,
|
||||
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
|
||||
CORSAllowCredentials: false,
|
||||
CORSMaxAge: 86400, // 24 hours
|
||||
|
||||
// Security features
|
||||
DisableServerHeader: true,
|
||||
DisablePoweredByHeader: true,
|
||||
}
|
||||
}
|
||||
|
||||
//lint:ignore U1000 Kept for backward compatibility
|
||||
func (s *Settings) setupHeaderTemplates(t interface{}, config *Config, logger Logger) error {
|
||||
logger.Debug("setupHeaderTemplates is deprecated")
|
||||
return nil
|
||||
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
|
||||
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
|
||||
if c == nil || !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the internal security config structure
|
||||
config := map[string]interface{}{
|
||||
"DevelopmentMode": false,
|
||||
}
|
||||
|
||||
// Apply profile-based defaults
|
||||
switch strings.ToLower(c.Profile) {
|
||||
case "strict":
|
||||
applyStrictProfile(config)
|
||||
case "development":
|
||||
applyDevelopmentProfile(config)
|
||||
case "api":
|
||||
applyAPIProfile(config)
|
||||
case "custom":
|
||||
// No defaults, use only what's explicitly configured
|
||||
default: // "default"
|
||||
applyDefaultProfile(config)
|
||||
}
|
||||
|
||||
// Override with explicit configuration
|
||||
if c.ContentSecurityPolicy != "" {
|
||||
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
|
||||
}
|
||||
|
||||
// HSTS configuration
|
||||
if c.StrictTransportSecurity {
|
||||
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
|
||||
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
|
||||
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
|
||||
}
|
||||
|
||||
// Frame options
|
||||
if c.FrameOptions != "" {
|
||||
config["FrameOptions"] = c.FrameOptions
|
||||
}
|
||||
|
||||
// Content type and XSS protection
|
||||
if c.ContentTypeOptions != "" {
|
||||
config["ContentTypeOptions"] = c.ContentTypeOptions
|
||||
}
|
||||
if c.XSSProtection != "" {
|
||||
config["XSSProtection"] = c.XSSProtection
|
||||
}
|
||||
|
||||
// Referrer and permissions policies
|
||||
if c.ReferrerPolicy != "" {
|
||||
config["ReferrerPolicy"] = c.ReferrerPolicy
|
||||
}
|
||||
if c.PermissionsPolicy != "" {
|
||||
config["PermissionsPolicy"] = c.PermissionsPolicy
|
||||
}
|
||||
|
||||
// Cross-origin policies
|
||||
if c.CrossOriginEmbedderPolicy != "" {
|
||||
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
|
||||
}
|
||||
if c.CrossOriginOpenerPolicy != "" {
|
||||
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
|
||||
}
|
||||
if c.CrossOriginResourcePolicy != "" {
|
||||
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
// CORS configuration
|
||||
config["CORSEnabled"] = c.CORSEnabled
|
||||
if len(c.CORSAllowedOrigins) > 0 {
|
||||
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
|
||||
}
|
||||
if len(c.CORSAllowedMethods) > 0 {
|
||||
config["CORSAllowedMethods"] = c.CORSAllowedMethods
|
||||
}
|
||||
if len(c.CORSAllowedHeaders) > 0 {
|
||||
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
|
||||
}
|
||||
config["CORSAllowCredentials"] = c.CORSAllowCredentials
|
||||
if c.CORSMaxAge > 0 {
|
||||
config["CORSMaxAge"] = c.CORSMaxAge
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
if len(c.CustomHeaders) > 0 {
|
||||
config["CustomHeaders"] = c.CustomHeaders
|
||||
}
|
||||
|
||||
// Security features
|
||||
config["DisableServerHeader"] = c.DisableServerHeader
|
||||
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future background service management
|
||||
func (s *Settings) startBackgroundServices(ctx context.Context, logger Logger) {
|
||||
startReplayCacheCleanup(ctx, logger)
|
||||
|
||||
// Start memory monitoring for leak detection and performance insights
|
||||
memoryMonitor := GetGlobalMemoryMonitor()
|
||||
memoryMonitor.StartMonitoring(ctx, 60*time.Second) // Monitor every minute
|
||||
logger.Debug("Started global memory monitoring")
|
||||
// applyDefaultProfile applies default security settings
|
||||
func applyDefaultProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-origin"
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
// applyStrictProfile applies strict security settings
|
||||
func applyStrictProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-site"
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future scope processing
|
||||
func deduplicateScopes(scopes []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
result := []string{}
|
||||
for _, scope := range scopes {
|
||||
if !seen[scope] {
|
||||
seen[scope] = true
|
||||
result = append(result, scope)
|
||||
// applyDevelopmentProfile applies development-friendly settings
|
||||
func applyDevelopmentProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
|
||||
config["FrameOptions"] = "SAMEORIGIN"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginOpenerPolicy"] = "unsafe-none"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
config["DevelopmentMode"] = true
|
||||
}
|
||||
|
||||
// applyAPIProfile applies API-friendly settings
|
||||
func applyAPIProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This would need to import the internal security package
|
||||
// For now, return a basic implementation
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
|
||||
// Apply basic security headers based on configuration
|
||||
if c.SecurityHeaders.FrameOptions != "" {
|
||||
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
|
||||
}
|
||||
if c.SecurityHeaders.ContentTypeOptions != "" {
|
||||
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
|
||||
}
|
||||
if c.SecurityHeaders.XSSProtection != "" {
|
||||
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
|
||||
}
|
||||
if c.SecurityHeaders.ReferrerPolicy != "" {
|
||||
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
|
||||
}
|
||||
if c.SecurityHeaders.ContentSecurityPolicy != "" {
|
||||
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// HSTS for HTTPS
|
||||
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
|
||||
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
|
||||
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
|
||||
hstsValue += "; includeSubDomains"
|
||||
}
|
||||
if c.SecurityHeaders.StrictTransportSecurityPreload {
|
||||
hstsValue += "; preload"
|
||||
}
|
||||
headers.Set("Strict-Transport-Security", hstsValue)
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
if c.SecurityHeaders.CORSEnabled {
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
|
||||
}
|
||||
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
|
||||
}
|
||||
if c.SecurityHeaders.CORSAllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if c.SecurityHeaders.CORSMaxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
for name, value := range c.SecurityHeaders.CustomHeaders {
|
||||
headers.Set(name, value)
|
||||
}
|
||||
|
||||
// Remove server headers
|
||||
if c.SecurityHeaders.DisableServerHeader {
|
||||
headers.Del("Server")
|
||||
}
|
||||
if c.SecurityHeaders.DisablePoweredByHeader {
|
||||
headers.Del("X-Powered-By")
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future scope merging operations
|
||||
func mergeScopes(defaultScopes, userScopes []string) []string {
|
||||
result := make([]string, len(defaultScopes))
|
||||
copy(result, defaultScopes)
|
||||
return append(result, userScopes...)
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future utility operations
|
||||
func createStringMap(items []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, item := range items {
|
||||
result[item] = struct{}{}
|
||||
// isOriginAllowed checks if an origin is in the allowed list
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
return true
|
||||
}
|
||||
// Simple wildcard matching for subdomains
|
||||
if strings.Contains(allowed, "*") {
|
||||
if strings.HasPrefix(allowed, "https://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "https://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(allowed, "http://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "http://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future case-insensitive operations
|
||||
func createCaseInsensitiveStringMap(items []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, item := range items {
|
||||
result[strings.ToLower(item)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future test environment detection
|
||||
func isTestMode() bool {
|
||||
// This function should be implemented based on environment detection logic
|
||||
return false
|
||||
}
|
||||
|
||||
// External dependencies that need to be provided
|
||||
// TraefikOidc struct is defined in types.go
|
||||
|
||||
// These functions need to be provided by external packages
|
||||
func NewLogger(level string) Logger { return nil }
|
||||
func CreateDefaultHTTPClient() *http.Client { return nil }
|
||||
func CreateTokenHTTPClient() *http.Client { return nil }
|
||||
func GetGlobalCacheManager(*sync.WaitGroup) CacheManager { return nil }
|
||||
func NewSessionManager(string, bool, string, Logger) (SessionManager, error) { return nil, nil }
|
||||
func NewErrorRecoveryManager(Logger) ErrorRecoveryManager { return nil }
|
||||
|
||||
//lint:ignore U1000 May be needed for future token claim extraction
|
||||
func extractClaims(string) (map[string]interface{}, error) { return nil, nil }
|
||||
|
||||
//lint:ignore U1000 May be needed for future replay attack prevention
|
||||
func startReplayCacheCleanup(context.Context, Logger) {}
|
||||
func GetGlobalMemoryMonitor() MemoryMonitor { return nil }
|
||||
|
||||
// Interfaces for external dependencies
|
||||
type CacheManager interface {
|
||||
GetSharedTokenBlacklist() CacheInterface
|
||||
GetSharedTokenCache() *TokenCache
|
||||
GetSharedMetadataCache() *MetadataCache
|
||||
GetSharedJWKCache() JWKCacheInterface
|
||||
Close() error
|
||||
}
|
||||
type SessionManager interface{}
|
||||
type ErrorRecoveryManager interface{}
|
||||
type MemoryMonitor interface {
|
||||
StartMonitoring(ctx context.Context, interval time.Duration)
|
||||
}
|
||||
type CacheInterface interface {
|
||||
Set(key string, value interface{}, ttl time.Duration)
|
||||
Get(key string) (interface{}, bool)
|
||||
Delete(key string)
|
||||
SetMaxSize(size int)
|
||||
Cleanup()
|
||||
Close()
|
||||
}
|
||||
type TokenCache struct{}
|
||||
type MetadataCache struct{}
|
||||
type JWKCacheInterface interface{}
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnifiedConfig is the master configuration structure consolidating all config aspects
|
||||
// This replaces 45 duplicate config structs across the codebase
|
||||
type UnifiedConfig struct {
|
||||
// Core Configuration
|
||||
Provider ProviderConfig `json:"provider" yaml:"provider"`
|
||||
Session SessionConfig `json:"session" yaml:"session"`
|
||||
Token TokenConfig `json:"token" yaml:"token"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Security SecurityConfig `json:"security" yaml:"security"`
|
||||
|
||||
// Middleware Configuration
|
||||
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
|
||||
Cache CacheConfig `json:"cache" yaml:"cache"`
|
||||
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
|
||||
|
||||
// Operational Configuration
|
||||
Logging LoggingConfig `json:"logging" yaml:"logging"`
|
||||
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
|
||||
Health HealthConfig `json:"health" yaml:"health"`
|
||||
|
||||
// Advanced Configuration
|
||||
Transport TransportConfig `json:"transport" yaml:"transport"`
|
||||
Pool PoolConfig `json:"pool" yaml:"pool"`
|
||||
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
|
||||
|
||||
// Compatibility field for migration
|
||||
Legacy map[string]interface{} `json:"-" yaml:"-"`
|
||||
}
|
||||
|
||||
// ProviderConfig contains OIDC provider settings
|
||||
type ProviderConfig struct {
|
||||
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
|
||||
ClientID string `json:"clientID" yaml:"clientID"`
|
||||
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
|
||||
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
|
||||
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
|
||||
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
|
||||
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
|
||||
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
|
||||
Discovery bool `json:"discovery" yaml:"discovery"`
|
||||
|
||||
// Provider-specific endpoints
|
||||
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
|
||||
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
|
||||
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
|
||||
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
|
||||
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
|
||||
}
|
||||
|
||||
// SessionConfig contains session management settings
|
||||
type SessionConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
MaxAge int `json:"maxAge" yaml:"maxAge"`
|
||||
Secret string `json:"secret" yaml:"secret"`
|
||||
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
|
||||
SigningKey string `json:"signingKey" yaml:"signingKey"`
|
||||
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
|
||||
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
|
||||
|
||||
// Cookie settings
|
||||
Domain string `json:"domain" yaml:"domain"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
Secure bool `json:"secure" yaml:"secure"`
|
||||
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
|
||||
SameSite string `json:"sameSite" yaml:"sameSite"`
|
||||
CookiePrefix string `json:"cookiePrefix" yaml:"cookiePrefix"` // Prefix for cookie names (e.g., "_oidc_myapp_")
|
||||
|
||||
// Storage settings
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
}
|
||||
|
||||
// TokenConfig contains token handling settings
|
||||
type TokenConfig struct {
|
||||
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
|
||||
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
|
||||
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
|
||||
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
|
||||
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
|
||||
|
||||
// Token caching
|
||||
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
|
||||
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
|
||||
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
|
||||
|
||||
// Token validation
|
||||
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
|
||||
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
|
||||
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
|
||||
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
|
||||
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
|
||||
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
|
||||
}
|
||||
|
||||
// SecurityConfig contains security-related settings
|
||||
type SecurityConfig struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
|
||||
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
|
||||
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
|
||||
|
||||
// CSRF protection
|
||||
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
|
||||
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
|
||||
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
|
||||
|
||||
// Additional security
|
||||
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
|
||||
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
|
||||
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
|
||||
}
|
||||
|
||||
// MiddlewareConfig contains middleware-specific settings
|
||||
type MiddlewareConfig struct {
|
||||
Priority int `json:"priority" yaml:"priority"`
|
||||
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
|
||||
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
|
||||
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
|
||||
|
||||
// Request handling
|
||||
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
|
||||
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
|
||||
// Response handling
|
||||
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
|
||||
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache configuration
|
||||
type CacheConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
|
||||
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
|
||||
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
|
||||
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
|
||||
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
|
||||
|
||||
// Memory cache settings
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
|
||||
// Distributed cache settings
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Compression bool `json:"compression" yaml:"compression"`
|
||||
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
|
||||
Burst int `json:"burst" yaml:"burst"`
|
||||
|
||||
// Rate limit storage
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
|
||||
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
|
||||
|
||||
// Rate limit keys
|
||||
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
|
||||
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
|
||||
|
||||
// Whitelisting
|
||||
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
|
||||
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
|
||||
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
|
||||
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
|
||||
FilePath string `json:"filePath" yaml:"filePath"`
|
||||
|
||||
// Log filtering
|
||||
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
|
||||
MaskFields []string `json:"maskFields" yaml:"maskFields"`
|
||||
|
||||
// Performance
|
||||
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
|
||||
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
|
||||
|
||||
// Audit logging
|
||||
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
|
||||
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
|
||||
}
|
||||
|
||||
// MetricsConfig contains metrics collection configuration
|
||||
type MetricsConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
|
||||
Endpoint string `json:"endpoint" yaml:"endpoint"`
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Subsystem string `json:"subsystem" yaml:"subsystem"`
|
||||
|
||||
// Collection settings
|
||||
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
|
||||
Histograms bool `json:"histograms" yaml:"histograms"`
|
||||
|
||||
// Custom labels
|
||||
Labels map[string]string `json:"labels" yaml:"labels"`
|
||||
}
|
||||
|
||||
// HealthConfig contains health check configuration
|
||||
type HealthConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
|
||||
// Checks to perform
|
||||
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
|
||||
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
|
||||
CheckCache bool `json:"checkCache" yaml:"checkCache"`
|
||||
|
||||
// Thresholds
|
||||
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
|
||||
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
|
||||
}
|
||||
|
||||
// TransportConfig contains HTTP transport configuration
|
||||
type TransportConfig struct {
|
||||
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
|
||||
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
|
||||
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
|
||||
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
|
||||
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
|
||||
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
|
||||
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
|
||||
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
|
||||
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
|
||||
|
||||
// TLS configuration
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
|
||||
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
|
||||
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
|
||||
|
||||
// Proxy settings
|
||||
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
|
||||
NoProxy []string `json:"noProxy" yaml:"noProxy"`
|
||||
}
|
||||
|
||||
// PoolConfig contains connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Size int `json:"size" yaml:"size"`
|
||||
MinSize int `json:"minSize" yaml:"minSize"`
|
||||
MaxSize int `json:"maxSize" yaml:"maxSize"`
|
||||
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
|
||||
|
||||
// Health checking
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// CircuitConfig contains circuit breaker configuration
|
||||
type CircuitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
|
||||
Interval time.Duration `json:"interval" yaml:"interval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
|
||||
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
|
||||
|
||||
// Circuit states
|
||||
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
|
||||
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
|
||||
|
||||
// Monitoring
|
||||
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
|
||||
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
|
||||
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Session.Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("Session.EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("Session.SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Redis.Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("Redis.SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
|
||||
t.Error("IssuerURL should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
|
||||
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
yamlBytes, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "secret: '[REDACTED]'") {
|
||||
t.Error("Session.Secret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
|
||||
t.Error("Session.EncryptionKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
|
||||
t.Error("Session.SigningKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "password: '[REDACTED]'") {
|
||||
t.Error("Redis.Password should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
|
||||
t.Error("Redis.SentinelPassword should be redacted in YAML output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
|
||||
t.Error("IssuerURL should be preserved in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderConfigMarshalling tests individual struct marshalling
|
||||
func TestProviderConfigMarshalling(t *testing.T) {
|
||||
provider := ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
|
||||
// Test YAML marshalling
|
||||
yamlBytes, err := yaml.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionConfigMarshalling tests session config marshalling
|
||||
func TestSessionConfigMarshalling(t *testing.T) {
|
||||
session := SessionConfig{
|
||||
Name: "session-cookie",
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
Domain: "example.com",
|
||||
Secure: true,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"name":"session-cookie"`) {
|
||||
t.Error("Name should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"domain":"example.com"`) {
|
||||
t.Error("Domain should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisConfigMarshalling tests Redis config marshalling
|
||||
func TestRedisConfigMarshalling(t *testing.T) {
|
||||
redis := RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
Addr: "localhost:6379",
|
||||
DB: 1,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(redis)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"addr":"localhost:6379"`) {
|
||||
t.Error("Addr should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"db":1`) {
|
||||
t.Error("DB should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
|
||||
func TestEmptySecretsNotRedacted(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "", // Empty secret
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "", // Empty secret
|
||||
EncryptionKey: "", // Empty secret
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "", // Empty secret
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify empty secrets are not shown as redacted
|
||||
if contains(jsonStr, "[REDACTED]") {
|
||||
t.Error("Empty secrets should not be shown as [REDACTED]")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
@@ -0,0 +1,652 @@
|
||||
// Package config provides validation for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationError represents a configuration validation error
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ValidationError) Error() string {
|
||||
if e.Value != nil {
|
||||
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
|
||||
}
|
||||
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidationErrors represents multiple validation errors
|
||||
type ValidationErrors []ValidationError
|
||||
|
||||
// Error implements the error interface
|
||||
func (e ValidationErrors) Error() string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var messages []string
|
||||
for _, err := range e {
|
||||
messages = append(messages, err.Error())
|
||||
}
|
||||
return strings.Join(messages, "; ")
|
||||
}
|
||||
|
||||
// Validate performs comprehensive validation on the unified configuration
|
||||
func (c *UnifiedConfig) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate Provider configuration
|
||||
if err := c.validateProvider(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Session configuration
|
||||
if err := c.validateSession(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Token configuration
|
||||
if err := c.validateToken(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Redis configuration (uses existing validation)
|
||||
if err := c.Redis.Validate(); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Redis",
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate Security configuration
|
||||
if err := c.validateSecurity(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Middleware configuration
|
||||
if err := c.validateMiddleware(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Cache configuration
|
||||
if err := c.validateCache(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate RateLimit configuration
|
||||
if err := c.validateRateLimit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Logging configuration
|
||||
if err := c.validateLogging(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Metrics configuration
|
||||
if err := c.validateMetrics(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Transport configuration
|
||||
if err := c.validateTransport(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Circuit configuration
|
||||
if err := c.validateCircuit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateProvider validates provider configuration
|
||||
func (c *UnifiedConfig) validateProvider() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// IssuerURL is required and must be a valid URL
|
||||
if c.Provider.IssuerURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "issuer URL is required",
|
||||
})
|
||||
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "invalid issuer URL",
|
||||
Value: c.Provider.IssuerURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ClientID is required
|
||||
if c.Provider.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
// ClientSecret is required (except for public clients with PKCE)
|
||||
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE for public clients)",
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectURL must be valid if provided
|
||||
if c.Provider.RedirectURL != "" {
|
||||
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.RedirectURL",
|
||||
Message: "invalid redirect URL",
|
||||
Value: c.Provider.RedirectURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes must include 'openid' for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range c.Provider.Scopes {
|
||||
if scope == "openid" {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID && !c.Provider.OverrideScopes {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.Scopes",
|
||||
Message: "scopes must include 'openid' for OIDC",
|
||||
Value: c.Provider.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// JWK cache period must be positive
|
||||
if c.Provider.JWKCachePeriod < 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.JWKCachePeriod",
|
||||
Message: "JWK cache period must be positive",
|
||||
Value: c.Provider.JWKCachePeriod,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSession validates session configuration
|
||||
func (c *UnifiedConfig) validateSession() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Session name must not be empty
|
||||
if c.Session.Name == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.Name",
|
||||
Message: "session name is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Session secret or encryption key is required
|
||||
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session",
|
||||
Message: "either session secret or encryption key is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Encryption key must be at least 32 bytes for security
|
||||
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "encryption key must be at least 32 characters for proper security",
|
||||
Value: len(c.Session.EncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
// ChunkSize must be reasonable (between 1KB and 10KB)
|
||||
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.ChunkSize",
|
||||
Message: "chunk size must be between 1000 and 10000 bytes",
|
||||
Value: c.Session.ChunkSize,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxChunks must be reasonable (between 1 and 100)
|
||||
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.MaxChunks",
|
||||
Message: "max chunks must be between 1 and 100",
|
||||
Value: c.Session.MaxChunks,
|
||||
})
|
||||
}
|
||||
|
||||
// SameSite must be valid
|
||||
validSameSite := map[string]bool{
|
||||
"": true,
|
||||
"Lax": true,
|
||||
"Strict": true,
|
||||
"None": true,
|
||||
}
|
||||
if !validSameSite[c.Session.SameSite] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.SameSite",
|
||||
Message: "invalid SameSite value (must be Lax, Strict, or None)",
|
||||
Value: c.Session.SameSite,
|
||||
})
|
||||
}
|
||||
|
||||
// StorageType must be valid
|
||||
validStorage := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"cookie": true,
|
||||
}
|
||||
if !validStorage[c.Session.StorageType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.StorageType",
|
||||
Message: "invalid storage type (must be memory, redis, or cookie)",
|
||||
Value: c.Session.StorageType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateToken validates token configuration
|
||||
func (c *UnifiedConfig) validateToken() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Token TTLs must be positive
|
||||
if c.Token.AccessTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.AccessTokenTTL",
|
||||
Message: "access token TTL must be positive",
|
||||
Value: c.Token.AccessTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Token.RefreshTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.RefreshTokenTTL",
|
||||
Message: "refresh token TTL must be positive",
|
||||
Value: c.Token.RefreshTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
// Validation mode must be valid
|
||||
validModes := map[string]bool{
|
||||
"jwt": true,
|
||||
"introspect": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validModes[c.Token.ValidationMode] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ValidationMode",
|
||||
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
|
||||
Value: c.Token.ValidationMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect URL required for introspect or hybrid mode
|
||||
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.IntrospectURL",
|
||||
Message: "introspect URL is required for introspect or hybrid validation mode",
|
||||
})
|
||||
}
|
||||
|
||||
// Clock skew must be reasonable (0 to 10 minutes)
|
||||
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ClockSkew",
|
||||
Message: "clock skew must be between 0 and 10 minutes",
|
||||
Value: c.Token.ClockSkew,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSecurity validates security configuration
|
||||
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate allowed user domains are valid domains
|
||||
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
|
||||
for _, domain := range c.Security.AllowedUserDomains {
|
||||
if !domainRegex.MatchString(domain) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.AllowedUserDomains",
|
||||
Message: "invalid domain format",
|
||||
Value: domain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Max login attempts must be reasonable
|
||||
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.MaxLoginAttempts",
|
||||
Message: "max login attempts must be between 0 and 100",
|
||||
Value: c.Security.MaxLoginAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
// Lockout duration must be reasonable
|
||||
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.LockoutDuration",
|
||||
Message: "lockout duration must be between 0 and 24 hours",
|
||||
Value: c.Security.LockoutDuration,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMiddleware validates middleware configuration
|
||||
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max request size must be reasonable (1KB to 100MB)
|
||||
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.MaxRequestSize",
|
||||
Message: "max request size must be between 1KB and 100MB",
|
||||
Value: c.Middleware.MaxRequestSize,
|
||||
})
|
||||
}
|
||||
|
||||
// Request timeout must be reasonable
|
||||
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.RequestTimeout",
|
||||
Message: "request timeout must be between 1 second and 5 minutes",
|
||||
Value: c.Middleware.RequestTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCache validates cache configuration
|
||||
func (c *UnifiedConfig) validateCache() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Cache.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Cache type must be valid
|
||||
validTypes := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validTypes[c.Cache.Type] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.Type",
|
||||
Message: "invalid cache type (must be memory, redis, or hybrid)",
|
||||
Value: c.Cache.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// Max entries must be reasonable
|
||||
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.MaxEntries",
|
||||
Message: "max entries must be between 10 and 1000000",
|
||||
Value: c.Cache.MaxEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// Eviction policy must be valid
|
||||
validEviction := map[string]bool{
|
||||
"lru": true,
|
||||
"lfu": true,
|
||||
"fifo": true,
|
||||
}
|
||||
if !validEviction[c.Cache.EvictionPolicy] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.EvictionPolicy",
|
||||
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
|
||||
Value: c.Cache.EvictionPolicy,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateRateLimit validates rate limiting configuration
|
||||
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.RateLimit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Requests per second must be reasonable
|
||||
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.RequestsPerSecond",
|
||||
Message: "requests per second must be between 1 and 10000",
|
||||
Value: c.RateLimit.RequestsPerSecond,
|
||||
})
|
||||
}
|
||||
|
||||
// Burst must be at least as large as requests per second
|
||||
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.Burst",
|
||||
Message: "burst must be at least as large as requests per second",
|
||||
Value: c.RateLimit.Burst,
|
||||
})
|
||||
}
|
||||
|
||||
// Key type must be valid
|
||||
validKeyTypes := map[string]bool{
|
||||
"ip": true,
|
||||
"user": true,
|
||||
"token": true,
|
||||
"custom": true,
|
||||
}
|
||||
if !validKeyTypes[c.RateLimit.KeyType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.KeyType",
|
||||
Message: "invalid key type (must be ip, user, token, or custom)",
|
||||
Value: c.RateLimit.KeyType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateLogging validates logging configuration
|
||||
func (c *UnifiedConfig) validateLogging() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Log level must be valid
|
||||
validLevels := map[string]bool{
|
||||
"debug": true,
|
||||
"info": true,
|
||||
"warn": true,
|
||||
"error": true,
|
||||
}
|
||||
if !validLevels[c.Logging.Level] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Level",
|
||||
Message: "invalid log level (must be debug, info, warn, or error)",
|
||||
Value: c.Logging.Level,
|
||||
})
|
||||
}
|
||||
|
||||
// Format must be valid
|
||||
validFormats := map[string]bool{
|
||||
"json": true,
|
||||
"text": true,
|
||||
"structured": true,
|
||||
}
|
||||
if !validFormats[c.Logging.Format] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Format",
|
||||
Message: "invalid log format (must be json, text, or structured)",
|
||||
Value: c.Logging.Format,
|
||||
})
|
||||
}
|
||||
|
||||
// Output must be valid
|
||||
validOutputs := map[string]bool{
|
||||
"stdout": true,
|
||||
"stderr": true,
|
||||
"file": true,
|
||||
}
|
||||
if !validOutputs[c.Logging.Output] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Output",
|
||||
Message: "invalid log output (must be stdout, stderr, or file)",
|
||||
Value: c.Logging.Output,
|
||||
})
|
||||
}
|
||||
|
||||
// File path required if output is file
|
||||
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.FilePath",
|
||||
Message: "file path is required when output is 'file'",
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMetrics validates metrics configuration
|
||||
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Metrics.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Provider must be valid
|
||||
validProviders := map[string]bool{
|
||||
"prometheus": true,
|
||||
"statsd": true,
|
||||
"otlp": true,
|
||||
}
|
||||
if !validProviders[c.Metrics.Provider] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Provider",
|
||||
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
|
||||
Value: c.Metrics.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
// Endpoint required for some providers
|
||||
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Endpoint",
|
||||
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateTransport validates transport configuration
|
||||
func (c *UnifiedConfig) validateTransport() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max connections must be reasonable
|
||||
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.MaxIdleConns",
|
||||
Message: "max idle connections must be between 0 and 10000",
|
||||
Value: c.Transport.MaxIdleConns,
|
||||
})
|
||||
}
|
||||
|
||||
// TLS min version must be valid
|
||||
validTLSVersions := map[string]bool{
|
||||
"TLS1.0": true,
|
||||
"TLS1.1": true,
|
||||
"TLS1.2": true,
|
||||
"TLS1.3": true,
|
||||
}
|
||||
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.TLSMinVersion",
|
||||
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
|
||||
Value: c.Transport.TLSMinVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// Proxy URL must be valid if provided
|
||||
if c.Transport.ProxyURL != "" {
|
||||
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.ProxyURL",
|
||||
Message: "invalid proxy URL",
|
||||
Value: c.Transport.ProxyURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCircuit validates circuit breaker configuration
|
||||
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Circuit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Consecutive failures must be reasonable
|
||||
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.ConsecutiveFailures",
|
||||
Message: "consecutive failures must be between 1 and 100",
|
||||
Value: c.Circuit.ConsecutiveFailures,
|
||||
})
|
||||
}
|
||||
|
||||
// Failure ratio must be between 0 and 1
|
||||
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.FailureRatio",
|
||||
Message: "failure ratio must be between 0 and 1",
|
||||
Value: c.Circuit.FailureRatio,
|
||||
})
|
||||
}
|
||||
|
||||
// OnOpen action must be valid
|
||||
validActions := map[string]bool{
|
||||
"reject": true,
|
||||
"fallback": true,
|
||||
"passthrough": true,
|
||||
}
|
||||
if !validActions[c.Circuit.OnOpen] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.OnOpen",
|
||||
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
|
||||
Value: c.Circuit.OnOpen,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
@@ -0,0 +1,588 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
|
||||
func TestValidateUnifiedConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *UnifiedConfig
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "valid config with minimum requirements",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "oidc_session",
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 5,
|
||||
StorageType: "cookie",
|
||||
},
|
||||
Token: TokenConfig{
|
||||
AccessTokenTTL: time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
ValidationMode: "jwt",
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
MaxRequestSize: 10 * 1024 * 1024,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.IssuerURL",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.ClientID",
|
||||
},
|
||||
{
|
||||
name: "encryption key too short",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "too-short",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.EncryptionKey",
|
||||
},
|
||||
{
|
||||
name: "invalid chunk size",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 500, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.ChunkSize",
|
||||
},
|
||||
{
|
||||
name: "invalid max chunks",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 0, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.MaxChunks",
|
||||
},
|
||||
{
|
||||
name: "invalid TLS min version",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Transport: TransportConfig{
|
||||
TLSMinVersion: "1.0", // Too old
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Transport.TLSMinVersion",
|
||||
},
|
||||
{
|
||||
name: "invalid circuit breaker failure ratio",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Circuit: CircuitConfig{
|
||||
Enabled: true,
|
||||
FailureRatio: 1.5, // Too high
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Circuit.FailureRatio",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
|
||||
} else if validationErrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, e := range validationErrs {
|
||||
if e.Field == tt.errorField {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected validation error for field %s, but got errors for: %v",
|
||||
tt.errorField, validationErrs)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidationErrorMessage tests validation error formatting
|
||||
func TestValidationErrorMessage(t *testing.T) {
|
||||
errs := ValidationErrors{
|
||||
{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "is required",
|
||||
Value: nil,
|
||||
},
|
||||
{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "must be at least 32 characters",
|
||||
Value: 16,
|
||||
},
|
||||
}
|
||||
|
||||
errMsg := errs.Error()
|
||||
|
||||
if !strings.Contains(errMsg, "Provider.IssuerURL") {
|
||||
t.Error("Error message should contain field name Provider.IssuerURL")
|
||||
}
|
||||
if !strings.Contains(errMsg, "is required") {
|
||||
t.Error("Error message should contain 'is required'")
|
||||
}
|
||||
if !strings.Contains(errMsg, "Session.EncryptionKey") {
|
||||
t.Error("Error message should contain field name Session.EncryptionKey")
|
||||
}
|
||||
if !strings.Contains(errMsg, "must be at least 32 characters") {
|
||||
t.Error("Error message should contain 'must be at least 32 characters'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedisConfig tests Redis configuration validation
|
||||
func TestValidateRedisConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *RedisConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid standalone config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing address for standalone",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Redis address is required",
|
||||
},
|
||||
{
|
||||
name: "valid cluster config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing cluster addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "cluster address is required",
|
||||
},
|
||||
{
|
||||
name: "valid sentinel config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing master name for sentinel",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Master name is required",
|
||||
},
|
||||
{
|
||||
name: "missing sentinel addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "sentinel address is required",
|
||||
},
|
||||
{
|
||||
name: "disabled redis needs no validation",
|
||||
config: &RedisConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid redis mode",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: "invalid-mode",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Invalid Redis mode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateRateLimit Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateRateLimit_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = false
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when rate limiting is disabled")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidConfig(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid rate limit config")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooLow(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0
|
||||
config.RateLimit.Burst = 100
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooHigh(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 15000
|
||||
config.RateLimit.Burst = 20000
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_BurstTooSmall(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 50 // Less than RequestsPerSecond
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.Burst", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "at least as large as requests per second")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_InvalidKeyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType string
|
||||
}{
|
||||
{"empty key type", ""},
|
||||
{"invalid key type", "invalid"},
|
||||
{"random string", "foobar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = tt.keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.KeyType", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid key type")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidKeyTypes(t *testing.T) {
|
||||
validKeyTypes := []string{"ip", "user", "token", "custom"}
|
||||
|
||||
for _, keyType := range validKeyTypes {
|
||||
t.Run(keyType, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid key type: %s", keyType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0 // Too low
|
||||
config.RateLimit.Burst = 50 // Will pass (0 < 50)
|
||||
config.RateLimit.KeyType = "invalid" // Invalid
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
// Should have 2 errors (rps and keyType)
|
||||
assert.Len(t, errors, 2)
|
||||
|
||||
// Check each error is present
|
||||
fields := make(map[string]bool)
|
||||
for _, err := range errors {
|
||||
fields[err.Field] = true
|
||||
}
|
||||
assert.True(t, fields["RateLimit.RequestsPerSecond"])
|
||||
assert.True(t, fields["RateLimit.KeyType"])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateMetrics Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateMetrics_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = false
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when metrics are disabled")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidPrometheus(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "prometheus"
|
||||
config.Metrics.Endpoint = "" // Prometheus doesn't require endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid prometheus config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidStatsd(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "localhost:8125"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid statsd config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidOTLP(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "localhost:4317"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid otlp config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_InvalidProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
}{
|
||||
{"empty provider", ""},
|
||||
{"invalid provider", "invalid"},
|
||||
{"datadog", "datadog"},
|
||||
{"influx", "influx"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = tt.provider
|
||||
config.Metrics.Endpoint = "localhost:8080"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid metrics provider")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMetrics_StatsdMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for statsd provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_OTLPMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for otlp provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "invalid" // Invalid provider
|
||||
config.Metrics.Endpoint = "" // Would be missing if provider was statsd/otlp
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
// Should have at least 1 error for invalid provider
|
||||
assert.NotEmpty(t, errors)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalYAML() (interface{}, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
// Test that CSRF tokens persist through the authentication flow
|
||||
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
|
||||
// Create a session manager
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create initial request
|
||||
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test that marking session as dirty forces save
|
||||
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test Azure-specific session handling
|
||||
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate Azure callback scenario
|
||||
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test session continuity through auth flow
|
||||
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 1: Initial request
|
||||
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test large token handling doesn't affect CSRF
|
||||
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
|
||||
// We can't fully initialize TraefikOidc without network access,
|
||||
// but we can test the session management directly
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
|
||||
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
// TestRegressionLoginLoop specifically tests the fix for issue #53
|
||||
func TestRegressionLoginLoop(t *testing.T) {
|
||||
// This test verifies that the specific changes made to fix the login loop work correctly
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the exact flow that was causing the login loop
|
||||
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
|
||||
func TestCSRFValidationTiming(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,364 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
|
||||
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Explicitly set defaults to test backward compatibility
|
||||
ts.tOidc.roleClaimName = "roles"
|
||||
ts.tOidc.groupClaimName = "groups"
|
||||
|
||||
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin", "users"},
|
||||
"roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
|
||||
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names for Auth0
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with Auth0-style namespaced claims
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{"admin", "users"},
|
||||
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
|
||||
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom simple claim names
|
||||
ts.tOidc.roleClaimName = "user_roles"
|
||||
ts.tOidc.groupClaimName = "user_groups"
|
||||
|
||||
// Create token with custom claim names
|
||||
claims := map[string]interface{}{
|
||||
"user_groups": []interface{}{"engineering", "product"},
|
||||
"user_roles": []interface{}{"developer", "manager"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
|
||||
t.Errorf("Expected groups [engineering product], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
|
||||
t.Errorf("Expected roles [developer manager], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
|
||||
func TestCustomClaimNames_MissingClaims(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token WITHOUT the custom claims
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should return empty slices, not error
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
|
||||
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with malformed role claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": "this-should-be-an-array",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed role claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_roles claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
|
||||
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with malformed group claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": 12345, // Not an array
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed group claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_groups claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
|
||||
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only role claim name (group uses default)
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "groups" // default
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin"},
|
||||
"https://myapp.com/roles": []interface{}{"editor"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin"}) {
|
||||
t.Errorf("Expected groups [admin], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor"}) {
|
||||
t.Errorf("Expected roles [editor], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
|
||||
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only group claim name (role uses default)
|
||||
ts.tOidc.roleClaimName = "roles" // default
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"roles": []interface{}{"viewer"},
|
||||
"https://myapp.com/groups": []interface{}{"users"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"users"}) {
|
||||
t.Errorf("Expected groups [users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"viewer"}) {
|
||||
t.Errorf("Expected roles [viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
|
||||
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with empty arrays
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{},
|
||||
"https://myapp.com/roles": []interface{}{},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
|
||||
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": []interface{}{"role1", 12345, "role2", true},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
|
||||
t.Errorf("Expected roles [role1 role2], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
|
||||
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
|
||||
t.Errorf("Expected groups [group1 group2], got %v", groups)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,424 @@
|
||||
# Auth0 Audience Validation Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Understanding Audiences](#understanding-audiences)
|
||||
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
|
||||
3. [Configuration Options](#configuration-options)
|
||||
4. [Security Recommendations](#security-recommendations)
|
||||
5. [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Understanding Audiences
|
||||
|
||||
### What is an Audience?
|
||||
|
||||
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
|
||||
|
||||
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
|
||||
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
|
||||
|
||||
### Why Does This Matter?
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
|
||||
---
|
||||
|
||||
## The Three Auth0 Scenarios
|
||||
|
||||
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request includes `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
|
||||
3. Middleware validates:
|
||||
- ID tokens against `client_id`
|
||||
- Access tokens against custom audience
|
||||
|
||||
**Result:** ✅ Fully secure, OIDC compliant
|
||||
|
||||
---
|
||||
|
||||
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# audience not specified (defaults to client_id)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request WITHOUT `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
|
||||
3. Access token validation fails (audience mismatch)
|
||||
4. Middleware falls back to ID token validation
|
||||
|
||||
**Security Warning:**
|
||||
```
|
||||
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
|
||||
⚠️ This could allow tokens intended for different APIs to grant access
|
||||
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
|
||||
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
|
||||
```
|
||||
|
||||
**Recommended Fix:**
|
||||
```yaml
|
||||
strictAudienceValidation: true # Reject sessions with audience mismatch
|
||||
```
|
||||
|
||||
**Result:**
|
||||
- Default: ⚠️ Works but logs security warnings
|
||||
- With strict mode: ✅ Secure (rejects mismatched tokens)
|
||||
|
||||
---
|
||||
|
||||
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true # Enable opaque token support
|
||||
requireTokenIntrospection: true # Require introspection (recommended)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Auth0 issues opaque (non-JWT) access token
|
||||
2. Middleware detects opaque token (not 3 parts separated by dots)
|
||||
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
|
||||
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
|
||||
|
||||
**Requirements:**
|
||||
- Provider must support `introspection_endpoint` in OIDC discovery
|
||||
- Client must have introspection permissions
|
||||
|
||||
**Result:** ✅ Secure with introspection, ⚠️ risky without
|
||||
|
||||
---
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Audience Settings
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `audience` | string | `client_id` | Expected audience for access tokens |
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
# .traefik.yml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
audience: "https://my-api.example.com"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Security Mode Settings
|
||||
|
||||
#### `strictAudienceValidation`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` for production
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
|
||||
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use in production environments
|
||||
- ✅ When you have custom API audiences configured in Auth0
|
||||
- ⚠️ May break existing deployments relying on Scenario 2 behavior
|
||||
|
||||
---
|
||||
|
||||
#### `allowOpaqueTokens`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Accepts opaque (non-JWT) access tokens
|
||||
- When `false`: Only accepts JWT access tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ When Auth0 issues opaque tokens (no default API configured)
|
||||
- ✅ When using Auth0 Management API tokens
|
||||
- ⚠️ Requires introspection endpoint for security
|
||||
|
||||
---
|
||||
|
||||
#### `requireTokenIntrospection`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` when `allowOpaqueTokens=true`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
|
||||
- When `false`: Falls back to ID token validation for opaque tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
|
||||
- ⚠️ Requires provider to expose introspection endpoint
|
||||
|
||||
---
|
||||
|
||||
## Security Recommendations
|
||||
|
||||
### Recommended Configuration for Auth0
|
||||
|
||||
**For APIs with custom audiences (Scenario 1):**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
allowOpaqueTokens: false
|
||||
```
|
||||
|
||||
**For default Auth0 setup (Scenario 2):**
|
||||
```yaml
|
||||
# Don't set audience (defaults to client_id)
|
||||
strictAudienceValidation: true # Enforce proper configuration
|
||||
```
|
||||
|
||||
**For opaque tokens (Scenario 3):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. ✅ **Always set `strictAudienceValidation: true` in production**
|
||||
2. ✅ **Configure custom API audiences in Auth0 dashboard**
|
||||
3. ✅ **Use `requireTokenIntrospection: true` if accepting opaque tokens**
|
||||
4. ✅ **Monitor logs for security warnings**
|
||||
5. ❌ **Don't rely on Scenario 2 fallback behavior**
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Access token validation failed due to audience mismatch"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
|
||||
```
|
||||
|
||||
**Cause:** Access token audience doesn't match configured audience
|
||||
|
||||
**Solutions:**
|
||||
1. **Configure correct audience:**
|
||||
```yaml
|
||||
audience: "https://your-api-identifier" # From Auth0 API settings
|
||||
```
|
||||
|
||||
2. **Update Auth0 authorization request:**
|
||||
- Ensure `audience` parameter is included in authorize URL
|
||||
- Middleware automatically adds this when `audience != client_id`
|
||||
|
||||
3. **Accept the behavior (not recommended):**
|
||||
```yaml
|
||||
strictAudienceValidation: false # Logs warnings but allows
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### "Opaque token detected but allowOpaqueTokens=false"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque access token detected but allowOpaqueTokens=false
|
||||
```
|
||||
|
||||
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
|
||||
|
||||
**Solutions:**
|
||||
1. **Enable opaque tokens:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT access tokens:**
|
||||
- Create an API in Auth0 dashboard
|
||||
- Set API identifier as `audience` in configuration
|
||||
|
||||
---
|
||||
|
||||
### "Introspection endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
|
||||
```
|
||||
|
||||
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
|
||||
|
||||
**Solutions:**
|
||||
1. **Check provider discovery:**
|
||||
```bash
|
||||
curl https://YOUR_DOMAIN/.well-known/openid-configuration
|
||||
```
|
||||
Look for `introspection_endpoint`
|
||||
|
||||
2. **Disable required introspection (less secure):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: false # Falls back to ID token
|
||||
```
|
||||
|
||||
3. **Use JWT access tokens instead** (recommended)
|
||||
|
||||
---
|
||||
|
||||
### "Token introspection required but endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
❌ SECURITY: Opaque token rejected (introspection required but failed)
|
||||
```
|
||||
|
||||
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
|
||||
|
||||
**Solutions:**
|
||||
1. **Disable required introspection:**
|
||||
```yaml
|
||||
requireTokenIntrospection: false
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT tokens** (better solution)
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Token Type Detection
|
||||
|
||||
The middleware uses a sophisticated 6-step detection algorithm:
|
||||
|
||||
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
|
||||
2. **Explicit type claims**: `token_use`, `token_type`
|
||||
3. **`scope` claim**: Present → Access Token
|
||||
4. **`nonce` claim**: Present → ID Token (OIDC spec)
|
||||
5. **Audience check**: `aud == client_id` only → ID Token
|
||||
6. **Default**: Access Token
|
||||
|
||||
### OAuth 2.0 Token Introspection (RFC 7662)
|
||||
|
||||
When opaque tokens are detected:
|
||||
|
||||
1. Middleware calls provider's `introspection_endpoint`
|
||||
2. Authenticates using client credentials
|
||||
3. Receives response with `active` status and claims
|
||||
4. Caches result for 5 minutes (configurable via TTL)
|
||||
5. Validates expiration, not-before, and audience if present
|
||||
|
||||
**Cache behavior:**
|
||||
- Cache key: Token hash
|
||||
- TTL: 5 minutes or token expiry (whichever is shorter)
|
||||
- Reduces introspection requests for frequently used tokens
|
||||
|
||||
---
|
||||
|
||||
## Reference Links
|
||||
|
||||
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
|
||||
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
|
||||
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
|
||||
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
|
||||
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
|
||||
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Previous Versions
|
||||
|
||||
**If you're upgrading from a version without these features:**
|
||||
|
||||
1. **No action required for default behavior** - backward compatible
|
||||
2. **Recommended: Enable strict mode gradually**
|
||||
```yaml
|
||||
# Step 1: Enable and monitor logs
|
||||
strictAudienceValidation: false # Default
|
||||
|
||||
# Step 2: After confirming no warnings, enable
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
3. **For opaque tokens: Enable explicitly**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
### Testing Your Configuration
|
||||
|
||||
1. **Check logs for warnings:**
|
||||
```bash
|
||||
# Look for Scenario 2 warnings
|
||||
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
|
||||
|
||||
# Look for opaque token warnings
|
||||
grep "Opaque" /var/log/traefik.log
|
||||
```
|
||||
|
||||
2. **Test with curl:**
|
||||
```bash
|
||||
# Get token from Auth0
|
||||
ACCESS_TOKEN="your_access_token"
|
||||
|
||||
# Test request
|
||||
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
|
||||
https://your-app.example.com/api
|
||||
```
|
||||
|
||||
3. **Monitor for security warnings in production logs**
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
|
||||
- Security issues: See SECURITY.md for responsible disclosure
|
||||
|
||||
---
|
||||
|
||||
**Last Updated:** 2025-01-09
|
||||
**Version:** 0.7.8+
|
||||
@@ -0,0 +1 @@
|
||||
traefikoidc.raczylo.com
|
||||
@@ -0,0 +1,955 @@
|
||||
# Provider-Specific Configuration Guide
|
||||
|
||||
This guide covers the configuration requirements and best practices for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [GitHub](#github)
|
||||
- [GitLab](#gitlab)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [Keycloak](#keycloak)
|
||||
- [Okta](#okta)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-google-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
|
||||
- **Refresh token support**: Fully supported
|
||||
- **Domain restrictions**: Can restrict by Google Workspace domains
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedUserDomains: ["example.com", "company.org"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google OAuth Console Setup
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# For Azure AD (single tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# For Azure AD (multi-tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-azure-application-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Azure-Specific Features
|
||||
- **Response mode**: Automatically adds `response_mode=query`
|
||||
- **Offline access**: Requires `offline_access` scope for refresh tokens
|
||||
- **Access token validation**: Supports both JWT and opaque access tokens
|
||||
- **Tenant isolation**: Can restrict to specific Azure AD tenants
|
||||
- **Application ID URI**: Supports custom audience for protected APIs
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Azure AD API Configuration (Application ID URI)
|
||||
|
||||
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
# Specify the Application ID URI as audience
|
||||
audience: "api://12345678-1234-1234-1234-123456789abc"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
|
||||
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected
|
||||
- For ID token validation only (no API access), you can omit the `audience` parameter
|
||||
|
||||
### Azure App Registration Setup
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to "Azure Active Directory" > "App registrations"
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Create client secret in "Certificates & secrets"
|
||||
6. Configure API permissions for required scopes
|
||||
|
||||
### Azure AD API Exposure Setup (for custom audiences)
|
||||
1. In your App Registration, go to "Expose an API"
|
||||
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
|
||||
3. Add any custom scopes your API exposes
|
||||
4. Update the middleware configuration to include the `audience` parameter with this URI
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Auth0-Specific Features
|
||||
- **Custom domains**: Supports Auth0 custom domains
|
||||
- **Rules and hooks**: Leverages Auth0's extensibility
|
||||
- **Social connections**: Works with Auth0's social identity providers
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
- **API audiences**: Supports custom audience for API access tokens
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedUsers: ["user@example.com", "admin@company.com"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Auth0 API Configuration (Custom Audience)
|
||||
|
||||
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
# Specify the Auth0 API identifier as audience
|
||||
audience: "https://api.company.com"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
|
||||
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
|
||||
- For ID token validation only (no APIs), you can omit the `audience` parameter
|
||||
|
||||
### Auth0 Application Setup
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create new application (Regular Web Application)
|
||||
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
|
||||
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
|
||||
### Auth0 API Setup (for custom audiences)
|
||||
1. Go to Auth0 Dashboard → APIs
|
||||
2. Create a new API or select existing API
|
||||
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
|
||||
4. In API Settings → Machine to Machine Applications, authorize your application
|
||||
5. Configure API permissions/scopes as needed
|
||||
6. Use the API identifier as the `audience` parameter in your configuration
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://github.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["read:user", "user:email"]
|
||||
```
|
||||
|
||||
### GitHub-Specific Features
|
||||
- **Organization membership**: Can restrict by GitHub organization
|
||||
- **Team membership**: Can restrict by specific teams
|
||||
- **Limited OIDC**: GitHub has limited OIDC support
|
||||
- **Email verification**: Requires verified email addresses
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
github-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://github.com"
|
||||
clientId: "Iv1.abcdef123456"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["read:user", "user:email"]
|
||||
allowedUsers: ["octocat", "github-user"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### GitHub OAuth App Setup
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
|
||||
4. Note the Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerUrl: "https://gitlab.com"
|
||||
|
||||
# Self-hosted GitLab
|
||||
providerUrl: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### GitLab-Specific Features
|
||||
- **Self-hosted support**: Works with self-hosted GitLab instances
|
||||
- **Group membership**: Can restrict by GitLab groups
|
||||
- **Project access**: Can validate project permissions
|
||||
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
# Note: GitLab doesn't support the offline_access scope.
|
||||
# Refresh tokens are issued automatically for the openid scope.
|
||||
allowedRolesAndGroups: ["developers", "maintainers"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Application Setup
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Save and note the Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-cognito-app-client-id"
|
||||
clientSecret: "your-cognito-app-client-secret" # If app client has secret
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Cognito-Specific Features
|
||||
- **User pools**: Integrates with Cognito User Pools
|
||||
- **Custom attributes**: Supports custom user attributes
|
||||
- **Groups**: Can validate Cognito user group membership
|
||||
- **Regional endpoints**: Requires region-specific URLs
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
cognito-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientId: "1234567890abcdefghij"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedRolesAndGroups: ["admin", "users"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/auth/callback`
|
||||
- Sign out URLs: `https://your-domain.com/auth/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Keycloak-Specific Features
|
||||
- **Realm support**: Multi-realm deployments
|
||||
- **Custom mappers**: Rich claim mapping capabilities
|
||||
- **Role-based access**: Fine-grained role management
|
||||
- **Offline access**: Full refresh token support
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
keycloak-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://keycloak.company.com/realms/employees"
|
||||
clientId: "traefik-app"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["app-users", "administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select appropriate realm
|
||||
3. Create new client:
|
||||
- Client Protocol: openid-connect
|
||||
- Access Type: confidential
|
||||
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.okta.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Okta-Specific Features
|
||||
- **Custom authorization servers**: Supports custom auth servers
|
||||
- **Group claims**: Rich group membership information
|
||||
- **Universal Directory**: Integrates with Okta's user store
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
okta-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.okta.com"
|
||||
clientId: "0oa123456789abcdef"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["Everyone", "Administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Application Setup
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect
|
||||
4. Choose Web Application
|
||||
5. Configure:
|
||||
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
|
||||
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
|
||||
- Grant types: Authorization Code, Refresh Token
|
||||
6. Assign users or groups
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-oidc-provider.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Generic Features
|
||||
- **Standards compliance**: Works with any OIDC-compliant provider
|
||||
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
|
||||
- **Flexible scopes**: Supports custom scope requirements
|
||||
- **Custom claims**: Works with provider-specific claims
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
generic-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://oidc.your-provider.com"
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
### Overview
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
|
||||
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
|
||||
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
|
||||
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
|
||||
|
||||
### Example Scenarios
|
||||
|
||||
#### Self-Hosted GitLab
|
||||
|
||||
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
|
||||
```
|
||||
The requested scope is invalid, unknown, or malformed.
|
||||
```
|
||||
|
||||
**Solution**: The middleware automatically detects this by:
|
||||
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
|
||||
2. Observing that `offline_access` is NOT in the `scopes_supported` list
|
||||
3. Filtering out `offline_access` from the request
|
||||
4. Authentication succeeds
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.example.com"
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
# Even though offline_access is listed, it will be automatically
|
||||
# filtered out if GitLab doesn't support it
|
||||
```
|
||||
|
||||
#### Auth0 or Keycloak
|
||||
|
||||
These providers typically support `offline_access` and it will be included:
|
||||
|
||||
```yaml
|
||||
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
|
||||
# Result: All requested scopes are sent
|
||||
```
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
|
||||
2. **No Manual Configuration**: No need to know which scopes each provider supports
|
||||
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
|
||||
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
|
||||
5. **Backward Compatible**: Existing configurations continue to work
|
||||
|
||||
### Logging
|
||||
|
||||
The middleware provides detailed logging for scope filtering:
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
|
||||
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Issue**: Provider rejects scope even after filtering
|
||||
|
||||
**Possible Causes**:
|
||||
1. Provider's discovery document is outdated
|
||||
2. Provider doesn't properly implement `scopes_supported`
|
||||
3. Custom authorization server with non-standard behavior
|
||||
|
||||
**Solutions**:
|
||||
1. Use `overrideScopes: true` and explicitly list only supported scopes
|
||||
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
|
||||
---
|
||||
|
||||
## Common Configuration Options
|
||||
|
||||
### Audience Configuration
|
||||
|
||||
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
|
||||
|
||||
```yaml
|
||||
# Optional: Custom audience for JWT validation
|
||||
# If not set, defaults to clientID for backward compatibility
|
||||
audience: "https://api.example.com" # Auth0 API identifier
|
||||
# OR
|
||||
audience: "api://12345-guid" # Azure AD Application ID URI
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- **Auth0**: When using Auth0 APIs with custom audience parameters
|
||||
- **Azure AD**: When exposing your app as an API with Application ID URI
|
||||
- **Keycloak**: When using audience-restricted tokens
|
||||
- **Okta**: When using custom authorization servers with API audiences
|
||||
|
||||
**When to omit**:
|
||||
- For standard ID token validation (default behavior)
|
||||
- When the provider sets `aud` claim to your `clientID`
|
||||
- For backward compatibility with existing configurations
|
||||
|
||||
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
|
||||
|
||||
### Security Settings
|
||||
```yaml
|
||||
# Force HTTPS (recommended for production)
|
||||
forceHttps: true
|
||||
|
||||
# Enable PKCE (recommended for security)
|
||||
enablePkce: true
|
||||
|
||||
# Session encryption key (32+ characters)
|
||||
sessionEncryptionKey: "your-very-long-encryption-key-here"
|
||||
```
|
||||
|
||||
### Access Control
|
||||
```yaml
|
||||
# Restrict by email addresses
|
||||
allowedUsers: ["user1@example.com", "user2@example.com"]
|
||||
|
||||
# Restrict by email domains
|
||||
allowedUserDomains: ["company.com", "partner.org"]
|
||||
|
||||
# Restrict by roles/groups (provider-specific)
|
||||
allowedRolesAndGroups: ["admin", "users", "developers"]
|
||||
```
|
||||
|
||||
### URLs and Endpoints
|
||||
```yaml
|
||||
# OAuth callback URL (must match provider config)
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
|
||||
# Logout endpoint
|
||||
logoutUrl: "https://your-domain.com/auth/logout"
|
||||
|
||||
# Post-logout redirect (optional)
|
||||
postLogoutRedirectUri: "https://your-domain.com"
|
||||
|
||||
# URLs to exclude from authentication
|
||||
excludedUrls: ["/health", "/metrics", "/public"]
|
||||
```
|
||||
|
||||
### Advanced Settings
|
||||
```yaml
|
||||
# Override default scopes
|
||||
overrideScopes: true
|
||||
scopes: ["openid", "custom_scope"]
|
||||
|
||||
# Rate limiting (requests per second)
|
||||
rateLimit: 10
|
||||
|
||||
# Token refresh grace period (seconds)
|
||||
refreshGracePeriodSeconds: 60
|
||||
|
||||
# Cookie domain (for subdomain sharing)
|
||||
cookieDomain: ".example.com"
|
||||
|
||||
# Custom headers to inject
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-Name"
|
||||
value: "{{.Claims.name}}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Invalid redirect URI**
|
||||
- Ensure callback URL exactly matches provider configuration
|
||||
- Check for HTTP vs HTTPS mismatches
|
||||
|
||||
2. **Scope errors**
|
||||
- Verify required scopes are configured in provider
|
||||
- Some providers require specific scopes for refresh tokens
|
||||
|
||||
3. **Token validation failures**
|
||||
- Check provider URL format and accessibility
|
||||
- Verify `.well-known/openid-configuration` endpoint is reachable
|
||||
|
||||
4. **Session issues**
|
||||
- Ensure session encryption key is properly configured
|
||||
- Check cookie domain settings for subdomain scenarios
|
||||
|
||||
### Debug Mode
|
||||
Enable debug logging to troubleshoot configuration issues:
|
||||
```yaml
|
||||
logLevel: "debug"
|
||||
```
|
||||
|
||||
This will provide detailed logs of the authentication flow and help identify configuration problems.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
|
||||
|
||||
### Default Security Headers
|
||||
|
||||
By default, the plugin applies these security headers:
|
||||
|
||||
- `X-Frame-Options: DENY` - Prevents clickjacking
|
||||
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
|
||||
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
|
||||
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
|
||||
|
||||
### Security Profiles
|
||||
|
||||
Choose from predefined security profiles or create custom configurations:
|
||||
|
||||
#### Default Profile (Recommended)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
#### Strict Profile (Maximum Security)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
# Additional strict CSP and cross-origin policies
|
||||
```
|
||||
|
||||
#### Development Profile (Local Development)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "development"
|
||||
# Relaxed policies for local development
|
||||
```
|
||||
|
||||
#### API Profile (API Endpoints)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://your-frontend.com"]
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
For complete control, use the custom profile:
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
|
||||
|
||||
# HSTS Configuration
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000 # 1 year
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and content protection
|
||||
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions policy (feature policy)
|
||||
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Cross-origin policies
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# CORS configuration
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://app.example.com"
|
||||
- "https://*.api.example.com"
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400 # 24 hours
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "custom-value"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Server identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### Complete Example with Security Headers
|
||||
|
||||
Here's a complete configuration example for Google OIDC with custom security headers:
|
||||
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
secure-google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
# OIDC Configuration
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key-here"
|
||||
|
||||
# Domain restrictions
|
||||
allowedUserDomains: ["your-company.com"]
|
||||
|
||||
# Security Headers
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.your-domain.com"
|
||||
corsAllowCredentials: true
|
||||
customHeaders:
|
||||
X-Company: "YourCompany"
|
||||
X-Environment: "production"
|
||||
|
||||
routers:
|
||||
secure-app:
|
||||
rule: "Host(`your-domain.com`)"
|
||||
middlewares:
|
||||
- secure-google-oidc
|
||||
service: your-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
```
|
||||
|
||||
### CORS Configuration Details
|
||||
|
||||
For applications with frontend-backend separation, configure CORS properly:
|
||||
|
||||
#### Simple CORS (Single Origin)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Wildcard Subdomains
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://*.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Development with Multiple Ports
|
||||
```yaml
|
||||
securityHeaders:
|
||||
profile: "development"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "http://localhost:*"
|
||||
- "http://127.0.0.1:*"
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
- Set `forceHttps: true`
|
||||
- Configure proper TLS certificates
|
||||
|
||||
2. **Implement proper CSP**
|
||||
- Start with strict policy
|
||||
- Add exceptions only when necessary
|
||||
- Test thoroughly
|
||||
|
||||
3. **Configure CORS restrictively**
|
||||
- Only allow necessary origins
|
||||
- Use specific domains instead of wildcards when possible
|
||||
|
||||
4. **Enable HSTS**
|
||||
- Use long max-age values (1 year minimum)
|
||||
- Include subdomains when appropriate
|
||||
|
||||
5. **Monitor security headers**
|
||||
- Use browser developer tools to verify headers
|
||||
- Test with security scanning tools
|
||||
- Regularly review and update policies
|
||||
|
||||
### Testing Security Headers
|
||||
|
||||
Use browser developer tools or online tools to verify your security headers:
|
||||
|
||||
1. **Browser DevTools**: Check Network tab → Response Headers
|
||||
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
|
||||
3. **Command line**: Use `curl -I https://your-domain.com`
|
||||
|
||||
Example verification:
|
||||
```bash
|
||||
curl -I https://your-domain.com
|
||||
# Should show security headers in response
|
||||
```
|
||||
+1125
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,413 @@
|
||||
# Redis Cache Backend Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Directory Organization
|
||||
|
||||
```
|
||||
internal/cache/
|
||||
├── backend/
|
||||
│ ├── interface.go # CacheBackend interface definition
|
||||
│ ├── interface_test.go # Contract tests for all backends
|
||||
│ ├── memory.go # In-memory backend implementation
|
||||
│ ├── memory_test.go # Memory backend unit tests
|
||||
│ ├── redis.go # Redis backend implementation
|
||||
│ ├── redis_test.go # Redis backend unit tests
|
||||
│ ├── errors.go # Error definitions
|
||||
│ └── test_helpers_test.go # Test infrastructure and helpers
|
||||
│
|
||||
└── resilience/
|
||||
├── circuit_breaker.go # Circuit breaker implementation
|
||||
├── circuit_breaker_test.go # Circuit breaker tests
|
||||
├── health_check.go # Health checker implementation
|
||||
└── health_check_test.go # Health check tests
|
||||
|
||||
redis_integration_test.go # End-to-end integration tests
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Interface Contract Tests (`interface_test.go`)
|
||||
|
||||
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
|
||||
|
||||
**Test Cases:**
|
||||
- `TestCacheBackendContract` - Runs all contract tests against each backend type
|
||||
- `testBasicSetGet` - Verifies basic set/get operations
|
||||
- `testGetNonExistent` - Tests behavior for non-existent keys
|
||||
- `testUpdateExisting` - Validates updating existing keys
|
||||
- `testDelete` - Tests delete operations
|
||||
- `testDeleteNonExistent` - Delete non-existent keys
|
||||
- `testExists` - Key existence checking
|
||||
- `testTTLExpiration` - TTL and expiration behavior
|
||||
- `testClear` - Clear all keys operation
|
||||
- `testPing` - Health check functionality
|
||||
- `testStats` - Statistics tracking
|
||||
- `testConcurrentAccess` - Thread safety with 10+ goroutines
|
||||
- `testLargeValues` - Handling of 1MB+ values
|
||||
- `testEmptyValues` - Empty byte array handling
|
||||
- `testSpecialCharactersInKeys` - Special characters in key names
|
||||
|
||||
**Coverage:** ~95% of interface methods
|
||||
|
||||
### 2. Memory Backend Tests (`memory_test.go`)
|
||||
|
||||
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (6 tests)
|
||||
- `TestMemoryBackend_BasicOperations` - CRUD operations
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- DeleteNonExistent
|
||||
- Exists
|
||||
- Clear
|
||||
|
||||
#### TTL and Expiration (3 tests)
|
||||
- `TestMemoryBackend_TTLExpiration`
|
||||
- ShortTTL (100ms)
|
||||
- TTLDecrement over time
|
||||
- CleanupExpiredItems
|
||||
|
||||
#### LRU Eviction (2 tests)
|
||||
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
|
||||
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
|
||||
|
||||
#### Edge Cases (6 tests)
|
||||
- `TestMemoryBackend_UpdateExisting` - Overwriting values
|
||||
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
|
||||
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
|
||||
- `TestMemoryBackend_LargeValues` - 1MB values
|
||||
- `TestMemoryBackend_Close` - Proper cleanup
|
||||
- `TestMemoryBackend_Ping` - Health checks
|
||||
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
|
||||
|
||||
**Coverage:** ~92% of memory backend code
|
||||
|
||||
### 3. Redis Backend Tests (`redis_test.go`)
|
||||
|
||||
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (4 tests)
|
||||
- `TestRedisBackend_BasicOperations`
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- Exists
|
||||
|
||||
#### Redis-Specific Features (6 tests)
|
||||
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
|
||||
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
|
||||
- `TestRedisBackend_Clear` - Bulk delete with SCAN
|
||||
- `TestRedisBackend_NoPrefix` - Operation without prefix
|
||||
|
||||
#### Error Handling (2 tests)
|
||||
- `TestRedisBackend_ConnectionFailure` - Connection errors
|
||||
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
|
||||
|
||||
#### Advanced Features (3 tests)
|
||||
- `TestRedisBackend_PipelineOperations`
|
||||
- SetMany (batch writes)
|
||||
- GetMany (batch reads)
|
||||
- GetManyWithNonExistent
|
||||
|
||||
#### Edge Cases (5 tests)
|
||||
- `TestRedisBackend_Stats` - Statistics tracking
|
||||
- `TestRedisBackend_Ping` - Connection health
|
||||
- `TestRedisBackend_Close` - Resource cleanup
|
||||
- `TestRedisBackend_UpdateExisting` - Overwrite handling
|
||||
- `TestRedisBackend_LargeValues` - 1MB values
|
||||
- `TestRedisBackend_EmptyValues` - Empty arrays
|
||||
|
||||
**Coverage:** ~88% of Redis backend code
|
||||
|
||||
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
|
||||
- All basic Redis commands
|
||||
- TTL and expiration
|
||||
- Time manipulation (FastForward)
|
||||
- Error simulation
|
||||
- No external Redis server required
|
||||
|
||||
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
|
||||
|
||||
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### State Transitions (5 tests)
|
||||
- `TestCircuitBreaker_StateTransitions`
|
||||
- Initial state (Closed)
|
||||
- Closed → Open (after max failures)
|
||||
- Open → HalfOpen (after timeout)
|
||||
- HalfOpen → Closed (after successful requests)
|
||||
- HalfOpen → Open (on failure)
|
||||
|
||||
#### Behavior Tests (5 tests)
|
||||
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
|
||||
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
|
||||
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
|
||||
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
|
||||
- `TestCircuitBreaker_Stats` - Statistics tracking
|
||||
|
||||
#### Advanced Tests (7 tests)
|
||||
- `TestCircuitBreaker_Reset` - Manual reset
|
||||
- `TestCircuitBreaker_StateChangeCallback` - Notifications
|
||||
- `TestCircuitBreaker_IsAvailable` - Availability check
|
||||
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
|
||||
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
|
||||
- `TestCircuitBreaker_DefaultConfig` - Default configuration
|
||||
- `TestCircuitBreaker_StateString` - String representation
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkCircuitBreaker_Execute` - Successful operations
|
||||
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
|
||||
|
||||
**Coverage:** ~95% of circuit breaker code
|
||||
|
||||
### 5. Health Check Tests (`health_check_test.go`)
|
||||
|
||||
**Purpose:** Validate periodic health checking and status management.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Status Transitions (4 tests)
|
||||
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
|
||||
- `TestHealthChecker_InitialState` - Default healthy state
|
||||
- `TestHealthChecker_ForceCheck` - Manual health check trigger
|
||||
- `TestHealthChecker_StatusChangeCallback` - Change notifications
|
||||
|
||||
#### Behavior Tests (6 tests)
|
||||
- `TestHealthChecker_Stats` - Statistics tracking
|
||||
- `TestHealthChecker_Timeout` - Check timeout handling
|
||||
- `TestHealthChecker_ConcurrentAccess` - Thread safety
|
||||
- `TestHealthChecker_StopAndStart` - Lifecycle management
|
||||
- `TestHealthChecker_DegradedState` - Degraded status detection
|
||||
- `TestHealthChecker_DefaultConfig` - Default settings
|
||||
|
||||
#### Advanced Tests (2 tests)
|
||||
- `TestHealthChecker_StatusString` - String representation
|
||||
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkHealthChecker_ForceCheck` - Check performance
|
||||
- `BenchmarkHealthChecker_Status` - Status read performance
|
||||
|
||||
**Coverage:** ~90% of health checker code
|
||||
|
||||
### 6. Integration Tests (`redis_integration_test.go`)
|
||||
|
||||
**Purpose:** End-to-end testing of real-world scenarios.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Multi-Instance Tests (3 tests)
|
||||
- `TestRedisIntegration_MultipleInstances`
|
||||
- ShareTokenBlacklist - JTI sharing across Traefik replicas
|
||||
- ShareTokenCache - Token cache sharing
|
||||
- ShareMetadataCache - Provider metadata sharing
|
||||
|
||||
#### Replay Detection (2 tests)
|
||||
- `TestRedisIntegration_JTIReplayDetection`
|
||||
- PreventReplayAcrossInstances - Block used JTIs
|
||||
- ConcurrentJTIChecks - Race condition handling
|
||||
|
||||
#### Resilience (1 test)
|
||||
- `TestRedisIntegration_Failover`
|
||||
- RedisTemporaryFailure - Recovery from temporary failures
|
||||
|
||||
#### Performance (1 test)
|
||||
- `TestRedisIntegration_HighLoad`
|
||||
- HighConcurrency - 50 goroutines × 100 operations
|
||||
|
||||
#### Consistency (2 tests)
|
||||
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
|
||||
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
|
||||
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
|
||||
|
||||
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
|
||||
|
||||
## Test Helpers and Infrastructure
|
||||
|
||||
### Test Helpers (`test_helpers_test.go`)
|
||||
|
||||
**Utilities:**
|
||||
- `TestLogger` - Logging for tests
|
||||
- `MiniredisServer` - Miniredis setup/teardown
|
||||
- `TestConfig` - Default test configurations
|
||||
- `GenerateTestData` - Test data generation
|
||||
- `GenerateLargeValue` - Large value creation
|
||||
- `AssertCacheStats` - Statistics validation
|
||||
- `WaitForCondition` - Async condition waiting
|
||||
- `AssertEventuallyExpires` - TTL expiration verification
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -v
|
||||
go test ./internal/cache/resilience/... -v
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Memory backend only
|
||||
go test ./internal/cache/backend -run TestMemoryBackend -v
|
||||
|
||||
# Redis backend only
|
||||
go test ./internal/cache/backend -run TestRedisBackend -v
|
||||
|
||||
# Circuit breaker only
|
||||
go test ./internal/cache/resilience -run TestCircuitBreaker -v
|
||||
|
||||
# Integration tests only
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -coverprofile=coverage.out
|
||||
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
go test ./internal/cache/backend -bench=. -benchmem
|
||||
go test ./internal/cache/resilience -bench=. -benchmem
|
||||
```
|
||||
|
||||
### Run with Race Detector
|
||||
```bash
|
||||
go test ./internal/cache/... -race -v
|
||||
```
|
||||
|
||||
## Test Patterns Used
|
||||
|
||||
### 1. Table-Driven Tests
|
||||
Used for testing multiple scenarios with similar structure.
|
||||
|
||||
### 2. Subtests (t.Run)
|
||||
Organized test cases into logical groups with clear names.
|
||||
|
||||
### 3. Parallel Tests
|
||||
Tests marked with `t.Parallel()` for faster execution.
|
||||
|
||||
### 4. Test Fixtures
|
||||
Reusable setup functions for common test data.
|
||||
|
||||
### 5. Mocking
|
||||
- `miniredis` for Redis operations
|
||||
- Mock functions for callbacks and health checks
|
||||
|
||||
### 6. Assertion Helpers
|
||||
Using `testify/assert` and `testify/require` for clear assertions.
|
||||
|
||||
## Test Coverage Summary
|
||||
|
||||
| Component | Coverage | Tests | Lines of Code |
|
||||
|-----------|----------|-------|---------------|
|
||||
| Interface Contract | 95% | 14 | ~200 |
|
||||
| Memory Backend | 92% | 18 | ~350 |
|
||||
| Redis Backend | 88% | 21 | ~400 |
|
||||
| Circuit Breaker | 95% | 17 | ~250 |
|
||||
| Health Checker | 90% | 12 | ~200 |
|
||||
| Integration Tests | 80% | 9 | ~300 |
|
||||
| **Total** | **90%** | **91** | **~1,700** |
|
||||
|
||||
## Edge Cases Tested
|
||||
|
||||
1. **Empty values** - Zero-length byte arrays
|
||||
2. **Large values** - 1MB+ data
|
||||
3. **Special characters** - Keys with :, /, -, _, ., |
|
||||
4. **Concurrent access** - 10-50 goroutines
|
||||
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
|
||||
6. **Connection failures** - Network errors, timeouts
|
||||
7. **Redis errors** - Simulated Redis failures
|
||||
8. **Memory limits** - Eviction under memory pressure
|
||||
9. **Race conditions** - Concurrent JTI checks
|
||||
10. **State transitions** - All circuit breaker and health check states
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Benchmarks included for:
|
||||
- Cache operations (Set, Get, Delete)
|
||||
- Circuit breaker execution
|
||||
- Health check operations
|
||||
- Concurrent access patterns
|
||||
- Large datasets (10,000+ items)
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Testing Libraries
|
||||
- `github.com/stretchr/testify` - Assertions and test utilities
|
||||
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
|
||||
- `github.com/redis/go-redis/v9` - Redis client
|
||||
|
||||
### Why Miniredis?
|
||||
- **No external dependencies** - No Redis server required
|
||||
- **Fast** - In-memory, perfect for unit tests
|
||||
- **Full Redis API** - Supports all operations we need
|
||||
- **Time manipulation** - FastForward for TTL testing
|
||||
- **Error simulation** - Test failure scenarios
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Tests
|
||||
1. Hybrid backend tests (L1/L2 cache)
|
||||
2. Network partition scenarios
|
||||
3. Redis cluster support
|
||||
4. Persistence and recovery tests
|
||||
5. Metrics and monitoring integration
|
||||
|
||||
### Test Infrastructure Improvements
|
||||
1. Test containers for real Redis integration
|
||||
2. Performance regression tracking
|
||||
3. Chaos engineering tests
|
||||
4. Load testing framework
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### Recommended CI Configuration
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- go test ./internal/cache/... -race -cover -v
|
||||
- go test -run TestRedisIntegration -v
|
||||
- go test ./internal/cache/... -bench=. -benchmem
|
||||
```
|
||||
|
||||
## Maintenance Guidelines
|
||||
|
||||
1. **Add tests for new features** - Maintain >85% coverage
|
||||
2. **Update contract tests** - When interface changes
|
||||
3. **Test edge cases** - Always test error paths
|
||||
4. **Document test purpose** - Clear comments explaining what each test validates
|
||||
5. **Keep tests fast** - Use t.Parallel() where possible
|
||||
6. **Mock external dependencies** - Use miniredis, not real Redis
|
||||
|
||||
## Conclusion
|
||||
|
||||
This comprehensive test suite provides:
|
||||
- **High confidence** in cache backend correctness
|
||||
- **Fast feedback** - Tests run in seconds
|
||||
- **Good coverage** - 90% overall
|
||||
- **Clear documentation** - Each test is well-documented
|
||||
- **Maintainability** - Clear structure and patterns
|
||||
|
||||
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
|
||||
+1373
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,550 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
// Required fields
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// Conditional - only for confidential clients
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
|
||||
// Optional - for managing registration
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
|
||||
// Expiration
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
|
||||
// Echo back of registered metadata
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationError represents an error response from client registration (RFC 7591)
|
||||
type ClientRegistrationError struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrar struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
providerURL string
|
||||
|
||||
// Cached registration response
|
||||
mu sync.RWMutex
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrar creates a new dynamic client registrar
|
||||
func NewDynamicClientRegistrar(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
return nil, fmt.Errorf("dynamic client registration is not enabled")
|
||||
}
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
return resp, nil
|
||||
}
|
||||
r.logger.Info("Existing credentials expired or invalid, registering new client")
|
||||
}
|
||||
}
|
||||
|
||||
// Determine registration endpoint
|
||||
endpoint := registrationEndpoint
|
||||
if r.config.RegistrationEndpoint != "" {
|
||||
endpoint = r.config.RegistrationEndpoint
|
||||
}
|
||||
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
|
||||
}
|
||||
|
||||
// Validate the endpoint URL
|
||||
if !strings.HasPrefix(endpoint, "https://") {
|
||||
// Allow http only for localhost/development
|
||||
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
|
||||
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
|
||||
}
|
||||
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
|
||||
}
|
||||
|
||||
// Build registration request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build registration request: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registration request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Add Initial Access Token if provided
|
||||
if r.config.InitialAccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read registration response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registration response: %w", err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if regResp.ClientID == "" {
|
||||
return nil, fmt.Errorf("registration response missing client_id")
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
|
||||
|
||||
// Cache the response
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// buildRegistrationRequest creates the JSON request body for client registration
|
||||
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
|
||||
metadata := r.config.ClientMetadata
|
||||
if metadata == nil {
|
||||
metadata = &ClientRegistrationMetadata{}
|
||||
}
|
||||
|
||||
// Build request object
|
||||
reqData := make(map[string]interface{})
|
||||
|
||||
// Required: redirect_uris
|
||||
if len(metadata.RedirectURIs) > 0 {
|
||||
reqData["redirect_uris"] = metadata.RedirectURIs
|
||||
} else {
|
||||
return nil, fmt.Errorf("redirect_uris is required for client registration")
|
||||
}
|
||||
|
||||
// Optional fields - only include if set
|
||||
if len(metadata.ResponseTypes) > 0 {
|
||||
reqData["response_types"] = metadata.ResponseTypes
|
||||
} else {
|
||||
// Default to authorization code flow
|
||||
reqData["response_types"] = []string{"code"}
|
||||
}
|
||||
|
||||
if len(metadata.GrantTypes) > 0 {
|
||||
reqData["grant_types"] = metadata.GrantTypes
|
||||
} else {
|
||||
// Default grant types for authorization code flow
|
||||
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
|
||||
}
|
||||
|
||||
if metadata.ApplicationType != "" {
|
||||
reqData["application_type"] = metadata.ApplicationType
|
||||
}
|
||||
|
||||
if len(metadata.Contacts) > 0 {
|
||||
reqData["contacts"] = metadata.Contacts
|
||||
}
|
||||
|
||||
if metadata.ClientName != "" {
|
||||
reqData["client_name"] = metadata.ClientName
|
||||
}
|
||||
|
||||
if metadata.LogoURI != "" {
|
||||
reqData["logo_uri"] = metadata.LogoURI
|
||||
}
|
||||
|
||||
if metadata.ClientURI != "" {
|
||||
reqData["client_uri"] = metadata.ClientURI
|
||||
}
|
||||
|
||||
if metadata.PolicyURI != "" {
|
||||
reqData["policy_uri"] = metadata.PolicyURI
|
||||
}
|
||||
|
||||
if metadata.TOSURI != "" {
|
||||
reqData["tos_uri"] = metadata.TOSURI
|
||||
}
|
||||
|
||||
if metadata.JWKSURI != "" {
|
||||
reqData["jwks_uri"] = metadata.JWKSURI
|
||||
}
|
||||
|
||||
if metadata.SubjectType != "" {
|
||||
reqData["subject_type"] = metadata.SubjectType
|
||||
}
|
||||
|
||||
if metadata.TokenEndpointAuthMethod != "" {
|
||||
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
|
||||
} else {
|
||||
// Default to client_secret_basic for confidential clients
|
||||
reqData["token_endpoint_auth_method"] = "client_secret_basic"
|
||||
}
|
||||
|
||||
if metadata.DefaultMaxAge > 0 {
|
||||
reqData["default_max_age"] = metadata.DefaultMaxAge
|
||||
}
|
||||
|
||||
if metadata.RequireAuthTime {
|
||||
reqData["require_auth_time"] = metadata.RequireAuthTime
|
||||
}
|
||||
|
||||
if len(metadata.DefaultACRValues) > 0 {
|
||||
reqData["default_acr_values"] = metadata.DefaultACRValues
|
||||
}
|
||||
|
||||
if metadata.Scope != "" {
|
||||
reqData["scope"] = metadata.Scope
|
||||
}
|
||||
|
||||
return json.Marshal(reqData)
|
||||
}
|
||||
|
||||
// GetCachedResponse returns the cached registration response
|
||||
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.registrationResponse
|
||||
}
|
||||
|
||||
// areCredentialsValid checks if the cached credentials are still valid
|
||||
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
|
||||
if resp == nil || resp.ClientID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if secret has expired
|
||||
if resp.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
|
||||
// Add 5 minute buffer before expiration
|
||||
if time.Now().Add(5 * time.Minute).After(expiresAt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// credentialsFilePath returns the path for storing credentials
|
||||
func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
if r.config.CredentialsFile != "" {
|
||||
return r.config.CredentialsFile
|
||||
}
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var resp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+5
-3
@@ -123,8 +123,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
|
||||
}
|
||||
|
||||
if metrics["total_requests"].(int64) > 0 {
|
||||
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
|
||||
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
|
||||
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
|
||||
if totalReq > 0 {
|
||||
successRate := float64(totalSucc) / float64(totalReq)
|
||||
metrics["success_rate"] = successRate
|
||||
} else {
|
||||
metrics["success_rate"] = 1.0
|
||||
@@ -963,7 +965,7 @@ func (gd *GracefulDegradation) Close() {
|
||||
// Don't set to nil to avoid race conditions
|
||||
}
|
||||
|
||||
gd.logger.Info("GracefulDegradation shut down successfully")
|
||||
gd.logger.Debug("GracefulDegradation shut down successfully")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDefaultCircuitBreakerConfig tests the default configuration function
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
// Test default values
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
|
||||
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Check expected metric fields
|
||||
expectedFields := []string{
|
||||
"total_requests",
|
||||
"total_failures",
|
||||
"total_successes",
|
||||
"uptime_seconds",
|
||||
"name",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if _, exists := metrics[field]; !exists {
|
||||
t.Errorf("Expected metric field %s to exist", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordRequest tests request recording
|
||||
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some requests
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalRequests := metrics["total_requests"].(int64)
|
||||
|
||||
if totalRequests != 3 {
|
||||
t.Errorf("Expected 3 total requests, got %d", totalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
|
||||
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some successes
|
||||
base.RecordSuccess()
|
||||
base.RecordSuccess()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalSuccesses := metrics["total_successes"].(int64)
|
||||
|
||||
if totalSuccesses != 2 {
|
||||
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
|
||||
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some failures
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalFailures := metrics["total_failures"].(int64)
|
||||
|
||||
if totalFailures != 3 {
|
||||
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogInfo tests info logging
|
||||
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogInfo("test message")
|
||||
base.LogInfo("test message with args: %s %d", "arg1", 42)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogInfo("test message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogError tests error logging
|
||||
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogError("error message")
|
||||
base.LogError("error message with args: %s %d", "error", 500)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogError("error message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogDebug tests debug logging
|
||||
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogDebug("debug message")
|
||||
base.LogDebug("debug message with args: %s %d", "debug", 123)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogDebug("debug message") // Should not panic
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetState tests getting circuit breaker state
|
||||
func TestCircuitBreaker_GetState(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initial state should be closed
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be closed, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests resetting circuit breaker
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Reset should not panic
|
||||
cb.Reset()
|
||||
|
||||
// State should be closed after reset
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be closed after reset, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsAvailable tests availability check
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initially should be available
|
||||
available := cb.IsAvailable()
|
||||
if !available {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Should include base metrics
|
||||
if _, exists := metrics["total_requests"]; !exists {
|
||||
t.Error("Expected total_requests in metrics")
|
||||
}
|
||||
|
||||
// Should include circuit breaker specific metrics
|
||||
if _, exists := metrics["state"]; !exists {
|
||||
t.Error("Expected state in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
// Retry mechanism tests removed due to complex dependencies
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,560 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRetryExecutorReset tests the Reset method
|
||||
func TestRetryExecutorReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
require.NotNil(t, executor)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
executor.Reset()
|
||||
})
|
||||
|
||||
// Multiple resets should be safe
|
||||
executor.Reset()
|
||||
executor.Reset()
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsAvailable tests the IsAvailable method
|
||||
func TestRetryExecutorIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
// Retry executor should always be available
|
||||
assert.True(t, executor.IsAvailable())
|
||||
|
||||
// Should remain available after operations
|
||||
ctx := context.Background()
|
||||
executor.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.True(t, executor.IsAvailable())
|
||||
}
|
||||
|
||||
// TestSessionErrorUnwrap tests SessionError.Unwrap
|
||||
func TestSessionErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("root cause")
|
||||
sessionErr := NewSessionError("save", "failed to save session", rootErr)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
sessionErr := NewSessionError("load", "failed to load session", nil)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("database error")
|
||||
sessionErr := NewSessionError("delete", "failed to delete session", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(sessionErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestTokenErrorUnwrap tests TokenError.Unwrap
|
||||
func TestTokenErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature verification failed")
|
||||
tokenErr := NewTokenError("id_token", "invalid", "token is invalid", rootErr)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
tokenErr := NewTokenError("access_token", "expired", "token has expired", nil)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("crypto error")
|
||||
tokenErr := NewTokenError("refresh_token", "malformed", "token is malformed", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(tokenErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterFallback tests fallback registration
|
||||
func TestGracefulDegradationRegisterFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register single fallback", func(t *testing.T) {
|
||||
fallback := func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
}
|
||||
|
||||
gd.RegisterFallback("service1", fallback)
|
||||
|
||||
// Verify fallback was registered (indirectly)
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return nil, errors.New("service failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("register multiple fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback2", nil
|
||||
})
|
||||
gd.RegisterFallback("service3", func() (interface{}, error) {
|
||||
return "fallback3", nil
|
||||
})
|
||||
|
||||
result2, _ := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
result3, _ := gd.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "fallback2", result2)
|
||||
assert.Equal(t, "fallback3", result3)
|
||||
})
|
||||
|
||||
t.Run("override existing fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "old fallback", nil
|
||||
})
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "new fallback", nil
|
||||
})
|
||||
|
||||
result, _ := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "new fallback", result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterHealthCheck tests health check registration
|
||||
func TestGracefulDegradationRegisterHealthCheck(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register health check", func(t *testing.T) {
|
||||
healthy := true
|
||||
healthCheck := func() bool {
|
||||
return healthy
|
||||
}
|
||||
|
||||
gd.RegisterHealthCheck("service1", healthCheck)
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
|
||||
// Set healthy and wait for health check to run
|
||||
healthy = true
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (may still be degraded due to timing, but health check was registered)
|
||||
})
|
||||
|
||||
t.Run("multiple health checks", func(t *testing.T) {
|
||||
gd.RegisterHealthCheck("service2", func() bool { return true })
|
||||
gd.RegisterHealthCheck("service3", func() bool { return false })
|
||||
|
||||
// Health checks are registered and will be called periodically
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithContext tests ExecuteWithContext
|
||||
func TestGracefulDegradationExecuteWithContext(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("successful execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testErr := errors.New("operation failed")
|
||||
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return testErr
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("uses fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("default", func() (interface{}, error) {
|
||||
return nil, nil // Success fallback
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return errors.New("primary failed")
|
||||
})
|
||||
|
||||
// With fallback succeeding, overall operation succeeds
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithFallback tests ExecuteWithFallback
|
||||
func TestGracefulDegradationExecuteWithFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("primary succeeds", func(t *testing.T) {
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return "primary result", nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "primary result", result)
|
||||
})
|
||||
|
||||
t.Run("fallback succeeds when primary fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("error when no fallback available", func(t *testing.T) {
|
||||
config.EnableFallbacks = false
|
||||
gdNoFallback := NewGracefulDegradation(config, logger)
|
||||
defer gdNoFallback.Close()
|
||||
|
||||
result, err := gdNoFallback.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("fallback also fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback also failed")
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback also failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsServiceDegraded tests service degradation status
|
||||
func TestGracefulDegradationIsServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 100 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("service not degraded initially", func(t *testing.T) {
|
||||
assert.False(t, gd.isServiceDegraded("new-service"))
|
||||
})
|
||||
|
||||
t.Run("service degraded after marking", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
})
|
||||
|
||||
t.Run("service recovers after timeout", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
assert.True(t, gd.isServiceDegraded("service2"))
|
||||
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("service2"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationMarkServiceDegraded tests marking services as degraded
|
||||
func TestGracefulDegradationMarkServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("mark single service", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service1")
|
||||
})
|
||||
|
||||
t.Run("mark multiple services", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service2")
|
||||
assert.Contains(t, degraded, "service3")
|
||||
})
|
||||
|
||||
t.Run("marking same service multiple times updates timestamp", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service4")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
// Service should still be marked as degraded
|
||||
assert.True(t, gd.isServiceDegraded("service4"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteFallback tests fallback execution
|
||||
func TestGracefulDegradationExecuteFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("execute registered fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) {
|
||||
return "fallback value", nil
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback value", result)
|
||||
})
|
||||
|
||||
t.Run("error when fallback not registered", func(t *testing.T) {
|
||||
result, err := gd.executeFallback("non-existent-service")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "no fallback available")
|
||||
})
|
||||
|
||||
t.Run("propagate fallback errors", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback error")
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service2")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback error")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationReset tests Reset method
|
||||
func TestGracefulDegradationReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("reset clears degraded services", func(t *testing.T) {
|
||||
// Mark several services as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 3)
|
||||
|
||||
// Reset
|
||||
gd.Reset()
|
||||
|
||||
// All should be cleared
|
||||
assert.Len(t, gd.GetDegradedServices(), 0)
|
||||
})
|
||||
|
||||
t.Run("can mark services degraded after reset", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 1)
|
||||
assert.Contains(t, gd.GetDegradedServices(), "service4")
|
||||
})
|
||||
|
||||
t.Run("multiple resets are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsAvailable tests IsAvailable method
|
||||
func TestGracefulDegradationIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Should always return true
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even with degraded services
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even after reset
|
||||
gd.Reset()
|
||||
assert.True(t, gd.IsAvailable())
|
||||
}
|
||||
|
||||
// TestGracefulDegradationGetMetrics tests GetMetrics method
|
||||
func TestGracefulDegradationGetMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("basic metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
require.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "degraded_services_count")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
assert.Contains(t, metrics, "registered_fallbacks_count")
|
||||
assert.Contains(t, metrics, "registered_health_checks_count")
|
||||
assert.Contains(t, metrics, "health_check_interval_seconds")
|
||||
assert.Contains(t, metrics, "recovery_timeout_seconds")
|
||||
assert.Contains(t, metrics, "fallbacks_enabled")
|
||||
})
|
||||
|
||||
t.Run("metrics reflect degraded services", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.Equal(t, 2, metrics["degraded_services_count"])
|
||||
degradedList := metrics["degraded_services"].([]string)
|
||||
assert.Len(t, degradedList, 2)
|
||||
})
|
||||
|
||||
t.Run("metrics reflect registered fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) { return nil, nil })
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) { return nil, nil })
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.GreaterOrEqual(t, metrics["registered_fallbacks_count"], 2)
|
||||
})
|
||||
|
||||
t.Run("metrics include base metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
// Should include base recovery mechanism metrics
|
||||
assert.Contains(t, metrics, "name")
|
||||
assert.Contains(t, metrics, "uptime_seconds")
|
||||
assert.Contains(t, metrics, "total_requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationFullScenario tests a complete degradation scenario
|
||||
func TestGracefulDegradationFullScenario(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping full scenario test in short mode")
|
||||
}
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 200 * time.Millisecond
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register fallback
|
||||
gd.RegisterFallback("critical-service", func() (interface{}, error) {
|
||||
return "fallback data", nil
|
||||
})
|
||||
|
||||
// Register health check
|
||||
serviceHealthy := false
|
||||
gd.RegisterHealthCheck("critical-service", func() bool {
|
||||
return serviceHealthy
|
||||
})
|
||||
|
||||
// First call - primary succeeds
|
||||
result1, err1 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "primary data", nil
|
||||
})
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, "primary data", result1)
|
||||
|
||||
// Second call - primary fails, fallback succeeds
|
||||
result2, err2 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service down")
|
||||
})
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, "fallback data", result2)
|
||||
|
||||
// Service is now degraded
|
||||
assert.True(t, gd.isServiceDegraded("critical-service"))
|
||||
|
||||
// Third call - should use fallback immediately
|
||||
result3, err3 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
assert.NoError(t, err3)
|
||||
assert.Equal(t, "fallback data", result3)
|
||||
|
||||
// Mark service as healthy and wait for health check
|
||||
serviceHealthy = true
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (timing-dependent, so we don't assert)
|
||||
|
||||
// Get metrics
|
||||
metrics := gd.GetMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreakerAllowRequestEdgeCases tests edge cases in circuit breaker request allowing
|
||||
func TestCircuitBreakerAllowRequestEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("invalid state returns false", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Force invalid state
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerState(999) // Invalid state
|
||||
cb.mutex.Unlock()
|
||||
|
||||
// Should return false for invalid state
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "invalid state should not allow requests")
|
||||
})
|
||||
|
||||
t.Run("open to half-open transition on timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: baseTimeout,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Verify circuit is open
|
||||
assert.Equal(t, CircuitBreakerOpen, cb.GetState())
|
||||
assert.False(t, cb.allowRequest())
|
||||
|
||||
// Wait for timeout (longer than timeout to ensure transition)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should transition to half-open
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "should allow request after timeout")
|
||||
assert.Equal(t, CircuitBreakerHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("half-open allows requests", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Manually set to half-open
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
cb.mutex.Unlock()
|
||||
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "half-open should allow requests")
|
||||
})
|
||||
|
||||
t.Run("open blocks requests before timeout", func(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 1 * time.Hour, // Long timeout
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should be blocked
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "open circuit should block requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsRetryableErrorEdgeCases tests edge cases for retry decision
|
||||
func TestRetryExecutorIsRetryableErrorEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRetryConfig()
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("nil error is not retryable", func(t *testing.T) {
|
||||
retryable := re.isRetryableError(nil)
|
||||
assert.False(t, retryable)
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 429 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 429,
|
||||
Message: "Too Many Requests",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "429 Too Many Requests should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 500 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 500,
|
||||
Message: "Internal Server Error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "500 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 503 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 503,
|
||||
Message: "Service Unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "503 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 400 is not retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 400,
|
||||
Message: "Bad Request",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.False(t, retryable, "400 errors should not be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with timeout is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: true,
|
||||
temporary: false,
|
||||
msg: "timeout error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "timeout errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection refused is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection refused",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection refused should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection reset is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection reset by peer",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection reset should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with network unreachable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "network is unreachable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "network unreachable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with no route to host is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "no route to host",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "no route to host should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with temporary failure is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "temporary failure in name resolution",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "temporary failure should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with try again is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "try again later",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "try again should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with resource temporarily unavailable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "resource temporarily unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "resource temporarily unavailable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("configured retryable error patterns", func(t *testing.T) {
|
||||
err := errors.New("connection refused by server")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.True(t, retryable, "configured pattern should be retryable")
|
||||
})
|
||||
|
||||
t.Run("non-retryable error", func(t *testing.T) {
|
||||
err := errors.New("invalid input data")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.False(t, retryable, "non-configured error should not be retryable")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorCalculateDelayEdgeCases tests delay calculation edge cases
|
||||
func TestRetryExecutorCalculateDelayEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("delay calculation without jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false, // Jitter disabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 1: 100ms * 2^0 = 100ms
|
||||
delay1 := re.calculateDelay(1)
|
||||
assert.Equal(t, 100*time.Millisecond, delay1)
|
||||
|
||||
// Attempt 2: 100ms * 2^1 = 200ms
|
||||
delay2 := re.calculateDelay(2)
|
||||
assert.Equal(t, 200*time.Millisecond, delay2)
|
||||
|
||||
// Attempt 3: 100ms * 2^2 = 400ms
|
||||
delay3 := re.calculateDelay(3)
|
||||
assert.Equal(t, 400*time.Millisecond, delay3)
|
||||
})
|
||||
|
||||
t.Run("delay calculation with jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true, // Jitter enabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// With jitter, delay should be within 10% of expected
|
||||
delay := re.calculateDelay(2)
|
||||
expectedBase := 200 * time.Millisecond
|
||||
minDelay := time.Duration(float64(expectedBase) * 0.9)
|
||||
maxDelay := time.Duration(float64(expectedBase) * 1.1)
|
||||
|
||||
assert.GreaterOrEqual(t, delay, minDelay, "delay should be >= 90% of base")
|
||||
assert.LessOrEqual(t, delay, maxDelay, "delay should be <= 110% of base")
|
||||
})
|
||||
|
||||
t.Run("delay capped at max delay", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 10,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 500 * time.Millisecond, // Low max delay
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 10: would be 100ms * 2^9 = 51200ms, but capped at 500ms
|
||||
delay := re.calculateDelay(10)
|
||||
assert.Equal(t, 500*time.Millisecond, delay, "delay should be capped at max")
|
||||
})
|
||||
|
||||
t.Run("delay with large backoff factor", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 3.0, // Larger backoff
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 3: 50ms * 3^2 = 450ms
|
||||
delay := re.calculateDelay(3)
|
||||
assert.Equal(t, 450*time.Millisecond, delay)
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorTypesErrorMethodsWithoutCause tests error type Error() methods without cause
|
||||
func TestErrorTypesErrorMethodsWithoutCause(t *testing.T) {
|
||||
t.Run("HTTPError.Error without cause", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 404,
|
||||
Message: "Not Found",
|
||||
}
|
||||
|
||||
errStr := httpErr.Error()
|
||||
assert.Equal(t, "HTTP 404: Not Found", errStr)
|
||||
})
|
||||
|
||||
t.Run("HTTPError.Error with different status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{200, "OK", "HTTP 200: OK"},
|
||||
{301, "Moved", "HTTP 301: Moved"},
|
||||
{401, "Unauthorized", "HTTP 401: Unauthorized"},
|
||||
{500, "Server Error", "HTTP 500: Server Error"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: tc.code,
|
||||
Message: tc.message,
|
||||
}
|
||||
assert.Equal(t, tc.expected, httpErr.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error without cause", func(t *testing.T) {
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_token",
|
||||
Message: "Token validation failed",
|
||||
Context: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Equal(t, "OIDC error [invalid_token]: Token validation failed", errStr)
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature mismatch")
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_signature",
|
||||
Message: "JWT signature invalid",
|
||||
Context: make(map[string]interface{}),
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Contains(t, errStr, "OIDC error [invalid_signature]: JWT signature invalid")
|
||||
assert.Contains(t, errStr, "caused by: signature mismatch")
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error without cause", func(t *testing.T) {
|
||||
sessErr := &SessionError{
|
||||
Operation: "load",
|
||||
Message: "Session not found",
|
||||
SessionID: "sess123",
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Equal(t, "Session error in load: Session not found", errStr)
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("database connection failed")
|
||||
sessErr := &SessionError{
|
||||
Operation: "save",
|
||||
Message: "Failed to persist session",
|
||||
SessionID: "sess456",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Contains(t, errStr, "Session error in save: Failed to persist session")
|
||||
assert.Contains(t, errStr, "caused by: database connection failed")
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error without cause", func(t *testing.T) {
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "access_token",
|
||||
Reason: "expired",
|
||||
Message: "Token has expired",
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Equal(t, "Token error (access_token) - expired: Token has expired", errStr)
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("time check failed")
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "id_token",
|
||||
Reason: "expired",
|
||||
Message: "Token validity period exceeded",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Contains(t, errStr, "Token error (id_token) - expired: Token validity period exceeded")
|
||||
assert.Contains(t, errStr, "caused by: time check failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationHealthChecks tests health check functionality
|
||||
func TestGracefulDegradationHealthChecks(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("performHealthChecks recovers degraded service", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns true
|
||||
healthCheckCalled := false
|
||||
gd.RegisterHealthCheck("test-service", func() bool {
|
||||
healthCheckCalled = true
|
||||
return true // Service is healthy
|
||||
})
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("test-service")
|
||||
|
||||
// Verify service is degraded
|
||||
assert.True(t, gd.isServiceDegraded("test-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Health check should have been called
|
||||
assert.True(t, healthCheckCalled, "health check should be called")
|
||||
|
||||
// Service should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("test-service"), "service should be recovered")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks marks service degraded on failure", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns false
|
||||
gd.RegisterHealthCheck("failing-service", func() bool {
|
||||
return false // Service is unhealthy
|
||||
})
|
||||
|
||||
// Initially not degraded
|
||||
assert.False(t, gd.isServiceDegraded("failing-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Service should be marked degraded
|
||||
assert.True(t, gd.isServiceDegraded("failing-service"), "service should be degraded")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks runs multiple health checks independently", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
service1Checked := false
|
||||
service2Checked := false
|
||||
|
||||
gd.RegisterHealthCheck("service1", func() bool {
|
||||
service1Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
gd.RegisterHealthCheck("service2", func() bool {
|
||||
service2Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
// Manually trigger health checks
|
||||
gd.performHealthChecks()
|
||||
|
||||
assert.True(t, service1Checked, "service1 health check should run")
|
||||
assert.True(t, service2Checked, "service2 health check should run")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks handles empty health checks", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Call performHealthChecks with no registered health checks
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
gd.performHealthChecks()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationServiceRecoveryTimeout tests recovery timeout behavior
|
||||
func TestGracefulDegradationServiceRecoveryTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("service auto-recovers after timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour, // Long interval, won't run during test
|
||||
RecoveryTimeout: baseTimeout,
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("auto-recover-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("auto-recover-service"))
|
||||
|
||||
// Wait for recovery timeout (longer than timeout to ensure recovery)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should auto-recover
|
||||
assert.False(t, gd.isServiceDegraded("auto-recover-service"), "service should auto-recover after timeout")
|
||||
})
|
||||
|
||||
t.Run("service remains degraded before timeout", func(t *testing.T) {
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour,
|
||||
RecoveryTimeout: 1 * time.Hour, // Very long timeout
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("long-timeout-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"))
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(10 * time.Millisecond))
|
||||
|
||||
// Should still be degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"), "service should remain degraded before timeout")
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorRecoveryManagerIntegration tests full integration of error recovery mechanisms
|
||||
func TestErrorRecoveryManagerIntegration(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
erm := NewErrorRecoveryManager(logger)
|
||||
|
||||
t.Run("circuit breaker and retry integration", func(t *testing.T) {
|
||||
// Create a circuit breaker with higher max failures to allow retries
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10, // High threshold
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}, logger)
|
||||
|
||||
erm.mutex.Lock()
|
||||
erm.circuitBreakers["test-service-integration"] = cb
|
||||
erm.mutex.Unlock()
|
||||
|
||||
attempts := 0
|
||||
fn := func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service-integration", fn)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, attempts, 3, "should retry until success")
|
||||
})
|
||||
|
||||
t.Run("circuit breaker opens on repeated failures", func(t *testing.T) {
|
||||
fn := func() error {
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
// First call - should fail after retries
|
||||
err1 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err1)
|
||||
|
||||
// Second call - should fail after retries
|
||||
err2 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err2)
|
||||
|
||||
// Check circuit breaker state
|
||||
cb := erm.GetCircuitBreaker("failing-service")
|
||||
state := cb.GetState()
|
||||
assert.Equal(t, CircuitBreakerOpen, state, "circuit should be open after repeated failures")
|
||||
})
|
||||
|
||||
t.Run("recovery metrics include all mechanisms", func(t *testing.T) {
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "circuit_breakers")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
})
|
||||
}
|
||||
|
||||
// TestContainsHelperFunction tests the contains helper function edge cases
|
||||
func TestContainsHelperFunction(t *testing.T) {
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("prefix match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout error occurred", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("suffix match", func(t *testing.T) {
|
||||
assert.True(t, contains("connection timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("middle match", func(t *testing.T) {
|
||||
assert.True(t, contains("a connection timeout error", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
assert.False(t, contains("connection refused", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("substring longer than string", func(t *testing.T) {
|
||||
assert.False(t, contains("abc", "abcdef"))
|
||||
})
|
||||
|
||||
t.Run("empty substring", func(t *testing.T) {
|
||||
assert.True(t, contains("test", ""))
|
||||
})
|
||||
|
||||
t.Run("empty string", func(t *testing.T) {
|
||||
assert.False(t, contains("", "test"))
|
||||
})
|
||||
|
||||
t.Run("both empty", func(t *testing.T) {
|
||||
assert.True(t, contains("", ""))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,848 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test Circuit Breaker State Transitions
|
||||
|
||||
func TestCircuitBreakerStateTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
failures int
|
||||
maxFailures int
|
||||
expectedStateBefore string
|
||||
expectedStateAfter string
|
||||
}{
|
||||
{
|
||||
name: "stays closed below threshold",
|
||||
failures: 1,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "closed",
|
||||
},
|
||||
{
|
||||
name: "opens at threshold",
|
||||
failures: 3,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "open",
|
||||
},
|
||||
{
|
||||
name: "opens above threshold",
|
||||
failures: 5,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "open",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: tt.maxFailures,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Verify initial state
|
||||
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateBefore {
|
||||
t.Errorf("Expected initial state %s, got %s", tt.expectedStateBefore, state)
|
||||
}
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < tt.failures; i++ {
|
||||
_ = cb.Execute(func() error {
|
||||
return errors.New("test failure")
|
||||
})
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateAfter {
|
||||
t.Errorf("Expected final state %s, got %s", tt.expectedStateAfter, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerHalfOpenTransition(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Circuit should be open after failures")
|
||||
}
|
||||
|
||||
// Wait for timeout to trigger half-open
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Next request should be allowed (half-open)
|
||||
allowed := false
|
||||
_ = cb.Execute(func() error {
|
||||
allowed = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if !allowed {
|
||||
t.Error("Request should be allowed in half-open state")
|
||||
}
|
||||
|
||||
// Successful request should close the circuit
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Circuit should be closed after successful half-open request, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerHalfOpenFailure(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Wait for half-open
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Fail in half-open state
|
||||
_ = cb.Execute(func() error {
|
||||
return errors.New("fail again")
|
||||
})
|
||||
|
||||
// Should return to open state
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Circuit should be open after half-open failure, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerConcurrency(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
failureCount := int64(0)
|
||||
|
||||
// Concurrent successful requests
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&failureCount, 1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if successCount != 100 {
|
||||
t.Errorf("Expected 100 successful requests, got %d", successCount)
|
||||
}
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics["total_requests"].(int64) != 100 {
|
||||
t.Errorf("Expected 100 total requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerReset(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Circuit should be open")
|
||||
}
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Error("Circuit should be closed after reset")
|
||||
}
|
||||
|
||||
// Should allow requests after reset
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Should allow requests after reset, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerMetrics(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Execute some requests
|
||||
_ = cb.Execute(func() error { return nil })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return nil })
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != 3 {
|
||||
t.Errorf("Expected 3 requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["total_successes"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 successes, got %d", metrics["total_successes"])
|
||||
}
|
||||
|
||||
if metrics["total_failures"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
||||
}
|
||||
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state 'closed', got %v", metrics["state"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerIsAvailable(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Should be available initially
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Circuit should be available initially")
|
||||
}
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should not be available when open
|
||||
if cb.IsAvailable() {
|
||||
t.Error("Circuit should not be available when open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be available in half-open
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Circuit should be available in half-open state")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Retry Executor
|
||||
|
||||
func TestRetryExecutorSuccess(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt for immediate success, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorEventualSuccess(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected success after retries, got %v", err)
|
||||
}
|
||||
|
||||
if attempts != 3 {
|
||||
t.Errorf("Expected 3 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorMaxAttemptsExceeded(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error after max attempts")
|
||||
}
|
||||
|
||||
if attempts != 3 {
|
||||
t.Errorf("Expected 3 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorNonRetryableError(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("permanent failure")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-retryable failure")
|
||||
}
|
||||
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorContextCancellation(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
done <- re.ExecuteWithContext(ctx, func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
}()
|
||||
|
||||
// Cancel after short delay
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
err := <-done
|
||||
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
if attempts == 0 {
|
||||
t.Error("Should have attempted at least once")
|
||||
}
|
||||
|
||||
if attempts >= 5 {
|
||||
t.Error("Should not have completed all attempts after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorExponentialBackoff(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 4,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
startTime := time.Now()
|
||||
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
|
||||
// Should have delays: 100ms, 200ms, 400ms = 700ms total (approx)
|
||||
if elapsed < 650*time.Millisecond || elapsed > 850*time.Millisecond {
|
||||
t.Errorf("Expected ~700ms elapsed with exponential backoff, got %v", elapsed)
|
||||
}
|
||||
|
||||
if attempts != 4 {
|
||||
t.Errorf("Expected 4 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorWithJitter(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
// Run multiple times to verify jitter adds variability
|
||||
durations := make([]time.Duration, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
startTime := time.Now()
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
durations[i] = time.Since(startTime)
|
||||
}
|
||||
|
||||
// Check that not all durations are identical (jitter should add variance)
|
||||
allSame := true
|
||||
for i := 1; i < len(durations); i++ {
|
||||
if durations[i] != durations[0] {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allSame {
|
||||
t.Error("Expected jitter to add variability to retry delays")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorNetworkErrors(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "timeout error",
|
||||
err: &mockNetError{timeout: true, temporary: true},
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "temporary network error",
|
||||
err: &mockNetError{timeout: false, temporary: true, msg: "temporary failure"},
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
err: &mockNetError{timeout: false, temporary: false, msg: "connection refused"},
|
||||
shouldRetry: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attempts := 0
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return tt.err
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorHTTPErrors(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
shouldRetry bool
|
||||
}{
|
||||
{"500 Internal Server Error", 500, true},
|
||||
{"502 Bad Gateway", 502, true},
|
||||
{"503 Service Unavailable", 503, true},
|
||||
{"429 Too Many Requests", 429, true},
|
||||
{"400 Bad Request", 400, false},
|
||||
{"404 Not Found", 404, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attempts := 0
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return &HTTPError{StatusCode: tt.statusCode, Message: "test"}
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorMetrics(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
}, nil)
|
||||
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
metrics := re.GetMetrics()
|
||||
|
||||
if metrics["max_attempts"] != 3 {
|
||||
t.Errorf("Expected max_attempts 3, got %v", metrics["max_attempts"])
|
||||
}
|
||||
|
||||
if metrics["backoff_factor"] != 2.0 {
|
||||
t.Errorf("Expected backoff_factor 2.0, got %v", metrics["backoff_factor"])
|
||||
}
|
||||
|
||||
if metrics["enable_jitter"] != true {
|
||||
t.Errorf("Expected enable_jitter true, got %v", metrics["enable_jitter"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Types
|
||||
|
||||
func TestOIDCErrorCreation(t *testing.T) {
|
||||
err := NewOIDCError("invalid_token", "Token is expired", nil)
|
||||
|
||||
if err.Code != "invalid_token" {
|
||||
t.Errorf("Expected code 'invalid_token', got %s", err.Code)
|
||||
}
|
||||
|
||||
if err.Message != "Token is expired" {
|
||||
t.Errorf("Expected message 'Token is expired', got %s", err.Message)
|
||||
}
|
||||
|
||||
expectedMsg := "OIDC error [invalid_token]: Token is expired"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCErrorWithCause(t *testing.T) {
|
||||
cause := errors.New("underlying error")
|
||||
err := NewOIDCError("token_error", "Failed to validate", cause)
|
||||
|
||||
if err.Unwrap() != cause {
|
||||
t.Error("Expected unwrap to return underlying cause")
|
||||
}
|
||||
|
||||
if err.Error() == "" {
|
||||
t.Error("Error string should include cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCErrorWithContext(t *testing.T) {
|
||||
err := NewOIDCError("auth_failed", "Authentication failed", nil).
|
||||
WithContext("provider", "google").
|
||||
WithContext("user_id", "12345")
|
||||
|
||||
if err.Context["provider"] != "google" {
|
||||
t.Errorf("Expected provider 'google', got %v", err.Context["provider"])
|
||||
}
|
||||
|
||||
if err.Context["user_id"] != "12345" {
|
||||
t.Errorf("Expected user_id '12345', got %v", err.Context["user_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionErrorCreation(t *testing.T) {
|
||||
err := NewSessionError("save", "Failed to save session", nil)
|
||||
|
||||
if err.Operation != "save" {
|
||||
t.Errorf("Expected operation 'save', got %s", err.Operation)
|
||||
}
|
||||
|
||||
expectedMsg := "Session error in save: Failed to save session"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionErrorWithSessionID(t *testing.T) {
|
||||
err := NewSessionError("load", "Session not found", nil).
|
||||
WithSessionID("sess_12345")
|
||||
|
||||
if err.SessionID != "sess_12345" {
|
||||
t.Errorf("Expected session ID 'sess_12345', got %s", err.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenErrorCreation(t *testing.T) {
|
||||
err := NewTokenError("id_token", "expired", "Token has expired", nil)
|
||||
|
||||
if err.TokenType != "id_token" {
|
||||
t.Errorf("Expected token type 'id_token', got %s", err.TokenType)
|
||||
}
|
||||
|
||||
if err.Reason != "expired" {
|
||||
t.Errorf("Expected reason 'expired', got %s", err.Reason)
|
||||
}
|
||||
|
||||
expectedMsg := "Token error (id_token) - expired: Token has expired"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Base Recovery Mechanism
|
||||
|
||||
func TestBaseRecoveryMechanismMetrics(t *testing.T) {
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", nil)
|
||||
|
||||
base.RecordRequest()
|
||||
base.RecordSuccess()
|
||||
base.RecordRequest()
|
||||
base.RecordFailure()
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["total_successes"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 success, got %d", metrics["total_successes"])
|
||||
}
|
||||
|
||||
if metrics["total_failures"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
||||
}
|
||||
|
||||
if metrics["success_rate"].(float64) != 0.5 {
|
||||
t.Errorf("Expected success rate 0.5, got %v", metrics["success_rate"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRecoveryMechanismConcurrentUpdates(t *testing.T) {
|
||||
base := NewBaseRecoveryMechanism("concurrent-test", nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
iterations := 1000
|
||||
|
||||
// Concurrent requests
|
||||
for i := 0; i < iterations; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
base.RecordRequest()
|
||||
if i%2 == 0 {
|
||||
base.RecordSuccess()
|
||||
} else {
|
||||
base.RecordFailure()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != int64(iterations) {
|
||||
t.Errorf("Expected %d requests, got %d", iterations, metrics["total_requests"])
|
||||
}
|
||||
|
||||
totalSuccessesAndFailures := metrics["total_successes"].(int64) + metrics["total_failures"].(int64)
|
||||
if totalSuccessesAndFailures != int64(iterations) {
|
||||
t.Errorf("Expected %d total successes+failures, got %d", iterations, totalSuccessesAndFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Recovery Manager
|
||||
|
||||
func TestErrorRecoveryManagerCreation(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
if erm == nil {
|
||||
t.Fatal("Expected non-nil error recovery manager")
|
||||
}
|
||||
|
||||
if erm.retryExecutor == nil {
|
||||
t.Error("Expected retry executor to be initialized")
|
||||
}
|
||||
|
||||
if erm.gracefulDegradation == nil {
|
||||
t.Error("Expected graceful degradation to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerGetCircuitBreaker(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
cb1 := erm.GetCircuitBreaker("service1")
|
||||
cb2 := erm.GetCircuitBreaker("service1")
|
||||
cb3 := erm.GetCircuitBreaker("service2")
|
||||
|
||||
if cb1 == nil || cb2 == nil || cb3 == nil {
|
||||
t.Fatal("Expected non-nil circuit breakers")
|
||||
}
|
||||
|
||||
// Should return same instance for same service
|
||||
if cb1 != cb2 {
|
||||
t.Error("Expected same circuit breaker instance for same service")
|
||||
}
|
||||
|
||||
// Should return different instances for different services
|
||||
if cb1 == cb3 {
|
||||
t.Error("Expected different circuit breaker instances for different services")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerExecuteWithRecovery(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
success := false
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service", func() error {
|
||||
success = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
t.Error("Expected function to execute")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerMetrics(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
// Create some circuit breakers
|
||||
_ = erm.GetCircuitBreaker("service1")
|
||||
_ = erm.GetCircuitBreaker("service2")
|
||||
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
cbMetrics, ok := metrics["circuit_breakers"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected circuit_breakers in metrics")
|
||||
}
|
||||
|
||||
if len(cbMetrics) != 2 {
|
||||
t.Errorf("Expected 2 circuit breakers in metrics, got %d", len(cbMetrics))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions and types
|
||||
|
||||
func circuitBreakerStateToString(state CircuitBreakerState) string {
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Mock network error for testing
|
||||
type mockNetError struct {
|
||||
timeout bool
|
||||
temporary bool
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *mockNetError) Error() string { return e.msg }
|
||||
func (e *mockNetError) Timeout() bool { return e.timeout }
|
||||
func (e *mockNetError) Temporary() bool { return e.temporary }
|
||||
|
||||
// Ensure mockNetError implements net.Error
|
||||
var _ net.Error = (*mockNetError)(nil)
|
||||
@@ -0,0 +1,486 @@
|
||||
# ============================================================================
|
||||
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
|
||||
# ============================================================================
|
||||
#
|
||||
# This example shows a complete, production-ready configuration for using
|
||||
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
|
||||
#
|
||||
|
||||
# ============================================================================
|
||||
# Part 1: Traefik Static Configuration (traefik.yml)
|
||||
# ============================================================================
|
||||
# This file configures Traefik itself and enables the plugin.
|
||||
# Place this in /etc/traefik/traefik.yml or mount it in your container.
|
||||
|
||||
---
|
||||
# Static Configuration
|
||||
api:
|
||||
dashboard: true
|
||||
insecure: false # Set to true only for local development
|
||||
|
||||
entryPoints:
|
||||
web:
|
||||
address: ":80"
|
||||
http:
|
||||
redirections:
|
||||
entryPoint:
|
||||
to: websecure
|
||||
scheme: https
|
||||
|
||||
websecure:
|
||||
address: ":443"
|
||||
http:
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
certificatesResolvers:
|
||||
letsencrypt:
|
||||
acme:
|
||||
email: admin@example.com
|
||||
storage: /letsencrypt/acme.json
|
||||
httpChallenge:
|
||||
entryPoint: web
|
||||
|
||||
providers:
|
||||
file:
|
||||
filename: /etc/traefik/dynamic.yml
|
||||
watch: true
|
||||
|
||||
# Enable the TraefikOIDC plugin
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
log:
|
||||
level: INFO
|
||||
format: json
|
||||
|
||||
accessLog:
|
||||
format: json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
|
||||
# ============================================================================
|
||||
# This file defines your routes, services, and middleware.
|
||||
# Place this in /etc/traefik/dynamic.yml
|
||||
|
||||
---
|
||||
http:
|
||||
# -------------------------------------------------------------------------
|
||||
# Middleware Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
middlewares:
|
||||
# Example 1: Minimal Redis Configuration
|
||||
# Perfect for getting started quickly
|
||||
oidc-minimal:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-application-client-id"
|
||||
clientSecret: "your-client-secret-from-provider"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
|
||||
|
||||
# Minimal Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
|
||||
# Example 2: Production Redis Configuration
|
||||
# Recommended for production deployments with multiple Traefik replicas
|
||||
oidc-production:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Provider Configuration
|
||||
clientID: "prod-client-id"
|
||||
clientSecret: "prod-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
# Security Settings
|
||||
forceHTTPS: true
|
||||
strictAudienceValidation: true
|
||||
|
||||
# Redis Configuration for Multi-Replica Deployment
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-master.redis-namespace.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:prod:"
|
||||
|
||||
# Cache Strategy
|
||||
cacheMode: "hybrid" # Fast local cache + shared Redis
|
||||
|
||||
# Connection Pooling
|
||||
poolSize: 20
|
||||
connectTimeout: 5
|
||||
readTimeout: 3
|
||||
writeTimeout: 3
|
||||
|
||||
# Resilience Features
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS (for production security)
|
||||
oidc-secure:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "secure-client-id"
|
||||
clientSecret: "secure-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Verify certificates in production
|
||||
cacheMode: "redis"
|
||||
|
||||
# Example 4: Hybrid Mode (Best Performance + Consistency)
|
||||
# Local cache for hot data, Redis for consistency across replicas
|
||||
oidc-hybrid:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "app-client-id"
|
||||
clientSecret: "app-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
cacheMode: "hybrid"
|
||||
|
||||
# Hybrid mode L1 cache settings
|
||||
hybridL1Size: 1000 # Number of items in local cache
|
||||
hybridL1MemoryMB: 20 # MB of memory for local cache
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Router Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
routers:
|
||||
# Protected application using OIDC authentication
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-production # Use the OIDC middleware
|
||||
service: my-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# Another app with minimal OIDC config
|
||||
simple-app:
|
||||
rule: "Host(`simple.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-minimal
|
||||
service: simple-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Service Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://my-app:8080"
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
|
||||
simple-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://simple-app:3000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 3: Docker Compose Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# docker-compose.yml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Redis service for shared caching
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Traefik with TraefikOIDC plugin
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
command:
|
||||
- "--api.dashboard=true"
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--providers.file.filename=/etc/traefik/dynamic.yml"
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.8.0"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # Dashboard
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
|
||||
- ./letsencrypt:/letsencrypt
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Your application
|
||||
my-app:
|
||||
image: my-app:latest
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
- "traefik.http.routers.my-app.entrypoints=websecure"
|
||||
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
|
||||
|
||||
# OIDC Middleware Configuration with Redis (using labels)
|
||||
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
|
||||
|
||||
# Redis configuration
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
networks:
|
||||
- traefik-network
|
||||
deploy:
|
||||
replicas: 3 # Multiple replicas sharing Redis cache
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
|
||||
networks:
|
||||
traefik-network:
|
||||
driver: bridge
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 4: Kubernetes Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# kubernetes-example.yaml
|
||||
|
||||
# Redis Deployment
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: redis
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: redis
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: redis:7-alpine
|
||||
args:
|
||||
- redis-server
|
||||
- --requirepass
|
||||
- $(REDIS_PASSWORD)
|
||||
- --maxmemory
|
||||
- 512mb
|
||||
- --maxmemory-policy
|
||||
- allkeys-lru
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: redis-secret
|
||||
key: password
|
||||
ports:
|
||||
- containerPort: 6379
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "100m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
---
|
||||
# Redis Service
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
selector:
|
||||
app: redis
|
||||
ports:
|
||||
- port: 6379
|
||||
targetPort: 6379
|
||||
---
|
||||
# Redis Secret
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: redis-secret
|
||||
namespace: traefik
|
||||
type: Opaque
|
||||
stringData:
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
---
|
||||
# OIDC Middleware with Redis
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Configuration
|
||||
clientID: "kubernetes-client-id"
|
||||
clientSecret: "kubernetes-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
|
||||
|
||||
# Redis Configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.traefik.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:k8s:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
enableHealthCheck: true
|
||||
---
|
||||
# IngressRoute using the middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: IngressRoute
|
||||
metadata:
|
||||
name: my-app
|
||||
namespace: default
|
||||
spec:
|
||||
entryPoints:
|
||||
- websecure
|
||||
routes:
|
||||
- match: Host(`app.example.com`)
|
||||
kind: Rule
|
||||
middlewares:
|
||||
- name: oidc-auth
|
||||
namespace: traefik
|
||||
services:
|
||||
- name: my-app
|
||||
port: 80
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 5: Environment Variables (Optional Fallback)
|
||||
# ============================================================================
|
||||
|
||||
# If you prefer environment variables as fallback (not recommended for production),
|
||||
# you can set these. NOTE: Plugin configuration takes precedence!
|
||||
|
||||
# Docker Compose env file (.env)
|
||||
---
|
||||
# OIDC Configuration
|
||||
OIDC_CLIENT_ID=your-client-id
|
||||
OIDC_CLIENT_SECRET=your-client-secret
|
||||
OIDC_PROVIDER_URL=https://auth.example.com
|
||||
|
||||
# Redis Configuration (fallback)
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=yourredispassword
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=20
|
||||
REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
REDIS_ENABLE_HEALTH_CHECK=true
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Cheat Sheet
|
||||
# ============================================================================
|
||||
|
||||
# Minimal Setup (Quick Start):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis:6379"
|
||||
|
||||
# Production Setup (Recommended):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis-master:6379"
|
||||
# password: "strong-password"
|
||||
# cacheMode: "hybrid"
|
||||
# enableCircuitBreaker: true
|
||||
# enableHealthCheck: true
|
||||
|
||||
# High Security Setup:
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis.example.com:6380"
|
||||
# password: "strong-password"
|
||||
# enableTLS: true
|
||||
# tlsSkipVerify: false
|
||||
# cacheMode: "redis"
|
||||
|
||||
# Cache Modes:
|
||||
# - "memory": Local cache only (default, no Redis needed)
|
||||
# - "redis": Redis only (consistent, shared across replicas)
|
||||
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
|
||||
@@ -0,0 +1,149 @@
|
||||
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
|
||||
# This example shows how to configure Redis through Traefik's dynamic configuration
|
||||
|
||||
# Static configuration (traefik.yml)
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
# Dynamic configuration (dynamic.yml or labels)
|
||||
http:
|
||||
middlewares:
|
||||
# Example 1: Basic Redis configuration
|
||||
oidc-redis-basic:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
# password: "your-redis-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
|
||||
# Example 2: Redis with resilience features
|
||||
oidc-redis-resilient:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with full resilience configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
|
||||
db: 1
|
||||
keyPrefix: "myapp:"
|
||||
poolSize: 20
|
||||
connectTimeout: 10
|
||||
readTimeout: 5
|
||||
writeTimeout: 5
|
||||
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
|
||||
# Circuit breaker settings
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
# Health check settings
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS
|
||||
oidc-redis-tls:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with TLS configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Set to true only for testing
|
||||
cacheMode: "redis"
|
||||
|
||||
routers:
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
middlewares:
|
||||
- oidc-redis-basic
|
||||
service: my-app-service
|
||||
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://localhost:8080"
|
||||
|
||||
# Docker Compose labels example
|
||||
# version: '3.8'
|
||||
# services:
|
||||
# traefik:
|
||||
# image: traefik:v3.0
|
||||
# # ... other config ...
|
||||
#
|
||||
# my-app:
|
||||
# image: my-app:latest
|
||||
# labels:
|
||||
# - "traefik.enable=true"
|
||||
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
# - "traefik.http.routers.my-app.middlewares=my-oidc"
|
||||
# # OIDC middleware configuration with Redis
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# # Redis configuration via labels
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
|
||||
#
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# command: redis-server --requirepass redis-password
|
||||
# # ... other config ...
|
||||
|
||||
# Environment variable fallback (optional)
|
||||
# If Redis configuration is not provided in Traefik config, these environment variables
|
||||
# can be used as a fallback (but Traefik config takes precedence):
|
||||
#
|
||||
# REDIS_ENABLED=true
|
||||
# REDIS_ADDRESS=redis:6379
|
||||
# REDIS_PASSWORD=secret
|
||||
# REDIS_DB=0
|
||||
# REDIS_KEY_PREFIX=traefikoidc:
|
||||
# REDIS_CACHE_MODE=redis
|
||||
# REDIS_POOL_SIZE=10
|
||||
# REDIS_CONNECT_TIMEOUT=5
|
||||
# REDIS_READ_TIMEOUT=3
|
||||
# REDIS_WRITE_TIMEOUT=3
|
||||
# REDIS_ENABLE_TLS=false
|
||||
# REDIS_TLS_SKIP_VERIFY=false
|
||||
# REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
|
||||
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
|
||||
# REDIS_ENABLE_HEALTH_CHECK=true
|
||||
# REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
@@ -3,15 +3,20 @@ module github.com/lukaszraczylo/traefikoidc
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.13.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -10,10 +20,14 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.logger.Debugf("Periodic task %s cancelled", name)
|
||||
m.logger.Debugf("Periodic task %s canceled", name)
|
||||
return
|
||||
case <-ticker.C:
|
||||
task()
|
||||
|
||||
@@ -0,0 +1,625 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test GoroutineManager Creation
|
||||
|
||||
func TestNewGoroutineManager(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
if gm == nil {
|
||||
t.Fatal("Expected non-nil goroutine manager")
|
||||
}
|
||||
|
||||
if gm.ctx == nil {
|
||||
t.Error("Expected context to be initialized")
|
||||
}
|
||||
|
||||
if gm.cancel == nil {
|
||||
t.Error("Expected cancel function to be initialized")
|
||||
}
|
||||
|
||||
if gm.goroutines == nil {
|
||||
t.Error("Expected goroutines map to be initialized")
|
||||
}
|
||||
|
||||
if gm.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Starting Goroutines
|
||||
|
||||
func TestStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
})
|
||||
|
||||
// Give goroutine time to execute
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if len(status) != 1 {
|
||||
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if _, exists := status["test-goroutine"]; !exists {
|
||||
t.Error("Expected goroutine 'test-goroutine' in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineDuplicate(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start a long-running goroutine
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
// Give first goroutine time to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to start another with same name (should be skipped)
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Should only have executed once
|
||||
if counter.Load() != 1 {
|
||||
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineContextCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
started := atomic.Bool{}
|
||||
canceled := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
|
||||
started.Store(true)
|
||||
<-ctx.Done()
|
||||
canceled.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !started.Load() {
|
||||
t.Error("Expected goroutine to start")
|
||||
}
|
||||
|
||||
// Stop the goroutine
|
||||
gm.StopGoroutine("cancel-test")
|
||||
|
||||
// Wait for cancellation
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !canceled.Load() {
|
||||
t.Error("Expected goroutine to be canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineWithPanic(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("panic-test", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// Give goroutine time to panic and recover
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute before panic")
|
||||
}
|
||||
|
||||
// Check that goroutine is marked as not running after panic
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running after panic")
|
||||
}
|
||||
}
|
||||
|
||||
// Manager should still be functional
|
||||
counter := atomic.Int32{}
|
||||
gm.StartGoroutine("after-panic", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 1 {
|
||||
t.Error("Expected manager to still be functional after panic recovery")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Periodic Tasks
|
||||
|
||||
func TestStartPeriodicTask(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for multiple executions
|
||||
time.Sleep(160 * time.Millisecond)
|
||||
|
||||
// Should have executed at least 2-3 times
|
||||
count := counter.Load()
|
||||
if count < 2 {
|
||||
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartPeriodicTaskCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for some executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
// Stop the task
|
||||
gm.StopGoroutine("cancel-periodic")
|
||||
|
||||
countBeforeStop := counter.Load()
|
||||
|
||||
// Wait and verify no more executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
countAfterStop := counter.Load()
|
||||
|
||||
// Allow 1 additional execution (could be in progress when stopped)
|
||||
if countAfterStop > countBeforeStop+1 {
|
||||
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
|
||||
countBeforeStop, countAfterStop)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Stopping Goroutines
|
||||
|
||||
func TestStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
stopped := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("stop-test", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
stopped.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
gm.StopGoroutine("stop-test")
|
||||
|
||||
// Wait for goroutine to stop
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !stopped.Load() {
|
||||
t.Error("Expected goroutine to be stopped")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["stop-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopGoroutineNonExistent(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Should not panic or error when stopping non-existent goroutine
|
||||
gm.StopGoroutine("non-existent")
|
||||
}
|
||||
|
||||
func TestStopGoroutineAlreadyStopped(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
|
||||
// Exit immediately
|
||||
})
|
||||
|
||||
// Wait for goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to stop already-stopped goroutine (should be safe)
|
||||
gm.StopGoroutine("already-stopped")
|
||||
}
|
||||
|
||||
// Test Shutdown
|
||||
|
||||
func TestShutdownGraceful(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start multiple goroutines
|
||||
for i := 0; i < 5; i++ {
|
||||
name := "goroutine-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
counter.Add(-1)
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 5 {
|
||||
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
|
||||
}
|
||||
|
||||
// Shutdown with generous timeout
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected graceful shutdown, got error: %v", err)
|
||||
}
|
||||
|
||||
if counter.Load() != 0 {
|
||||
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownWithTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
|
||||
gm.StartGoroutine("stubborn", func(ctx context.Context) {
|
||||
// Simulate a goroutine that takes too long to stop
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Shutdown with very short timeout
|
||||
err := gm.Shutdown(10 * time.Millisecond)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error")
|
||||
}
|
||||
|
||||
if err != ErrShutdownTimeout {
|
||||
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown with no goroutines should succeed immediately
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for empty shutdown, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Status
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start multiple goroutines with different states
|
||||
gm.StartGoroutine("running", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
gm.StartGoroutine("quick", func(ctx context.Context) {
|
||||
// Exits immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if len(status) != 2 {
|
||||
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if runningStatus, exists := status["running"]; exists {
|
||||
if !runningStatus.Running {
|
||||
t.Error("Expected 'running' goroutine to be marked as running")
|
||||
}
|
||||
|
||||
if runningStatus.Name != "running" {
|
||||
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
|
||||
}
|
||||
|
||||
if runningStatus.StartTime.IsZero() {
|
||||
t.Error("Expected non-zero start time")
|
||||
}
|
||||
|
||||
if runningStatus.Runtime <= 0 {
|
||||
t.Error("Expected positive runtime")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'running' goroutine in status")
|
||||
}
|
||||
|
||||
if quickStatus, exists := status["quick"]; exists {
|
||||
if quickStatus.Running {
|
||||
t.Error("Expected 'quick' goroutine to be marked as not running")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'quick' goroutine in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatusEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if status == nil {
|
||||
t.Fatal("Expected non-nil status map")
|
||||
}
|
||||
|
||||
if len(status) != 0 {
|
||||
t.Errorf("Expected empty status, got %d entries", len(status))
|
||||
}
|
||||
}
|
||||
|
||||
// Test Concurrent Operations
|
||||
|
||||
func TestConcurrentStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(2 * time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
const numGoroutines = 50
|
||||
|
||||
// Start many goroutines concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
counter.Add(-1)
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Verify goroutines are tracked
|
||||
status := gm.GetStatus()
|
||||
if len(status) < numGoroutines/2 {
|
||||
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
const numGoroutines = 20
|
||||
|
||||
// Start goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
name := "stop-concurrent-" + string(rune('0'+i%10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop all concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "stop-concurrent-" + string(rune('0'+id%10))
|
||||
gm.StopGoroutine(name)
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all stopped
|
||||
status := gm.GetStatus()
|
||||
for _, s := range status {
|
||||
if s.Running {
|
||||
t.Errorf("Expected goroutine %s to be stopped", s.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start some goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
name := "status-test-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrently read status many times (should not race)
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 20; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = gm.GetStatus()
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all concurrent reads
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Cases
|
||||
|
||||
func TestShutdownTimeoutError(t *testing.T) {
|
||||
err := ErrShutdownTimeout
|
||||
|
||||
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
|
||||
t.Errorf("Unexpected error message: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Edge Cases
|
||||
|
||||
func TestStartGoroutineAfterShutdown(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown immediately
|
||||
_ = gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
// Try to start goroutine after shutdown
|
||||
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Goroutine should have started but context already canceled
|
||||
// It may or may not execute depending on timing, but shouldn't panic
|
||||
status := gm.GetStatus()
|
||||
if _, exists := status["after-shutdown"]; exists {
|
||||
// If it's in status, it was tracked (acceptable)
|
||||
t.Log("Goroutine was tracked even after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleShutdowns(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// First shutdown
|
||||
err1 := gm.Shutdown(time.Second)
|
||||
if err1 != nil {
|
||||
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
|
||||
}
|
||||
|
||||
// Second shutdown (should not panic or error)
|
||||
err2 := gm.Shutdown(time.Second)
|
||||
if err2 != nil {
|
||||
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoroutineWithImmediateReturn(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("immediate", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
// Return immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["immediate"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected immediately-returning goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeriodicTaskPanicRecovery(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
if counter.Load() == 2 {
|
||||
panic("periodic panic")
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for panic to occur
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// After panic, the goroutine should have stopped
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-periodic"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected panicked periodic task to stop")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -147,7 +147,12 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
|
||||
// Log cookie names only, not values (avoid logging sensitive session data)
|
||||
cookieNames := make([]string, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
h.logger.Debugf("Available cookies (names only): %v", cookieNames)
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
|
||||
@@ -0,0 +1,899 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test mocks - implementing interfaces defined in oauth_handler.go
|
||||
type mockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Error(msg string) {
|
||||
l.errorMessages = append(l.errorMessages, msg)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
sessionToReturn SessionData
|
||||
errorToReturn error
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
|
||||
return m.sessionToReturn, m.errorToReturn
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
saveError error
|
||||
setAuthError error
|
||||
}
|
||||
|
||||
func (s *mockSessionData) GetCSRF() string { return s.csrf }
|
||||
func (s *mockSessionData) GetNonce() string { return s.nonce }
|
||||
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
|
||||
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
|
||||
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
|
||||
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
|
||||
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
|
||||
func (s *mockSessionData) GetIDToken() string { return s.idToken }
|
||||
func (s *mockSessionData) GetEmail() string { return s.email }
|
||||
|
||||
func (s *mockSessionData) SetAuthenticated(auth bool) error {
|
||||
s.authenticated = auth
|
||||
return s.setAuthError
|
||||
}
|
||||
|
||||
func (s *mockSessionData) SetEmail(email string) { s.email = email }
|
||||
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
|
||||
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
|
||||
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
|
||||
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
|
||||
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
|
||||
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
|
||||
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
|
||||
func (s *mockSessionData) ResetRedirectCount() {}
|
||||
func (s *mockSessionData) returnToPoolSafely() {}
|
||||
|
||||
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return s.saveError
|
||||
}
|
||||
|
||||
type mockTokenExchanger struct {
|
||||
response *TokenResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
return e.response, e.err
|
||||
}
|
||||
|
||||
type mockTokenVerifier struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *mockTokenVerifier) VerifyToken(token string) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
// TestOAuthHandler_NewOAuthHandler tests the constructor
|
||||
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.redirURLPath != "/callback" {
|
||||
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
|
||||
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Session error") {
|
||||
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
|
||||
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Authentication error from provider") {
|
||||
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
|
||||
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "State parameter missing") {
|
||||
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
|
||||
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: ""} // Empty CSRF
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF token missing") {
|
||||
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
|
||||
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "different-token"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF mismatch") {
|
||||
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
|
||||
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "No authorization code received") {
|
||||
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
|
||||
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not exchange code for token") {
|
||||
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
|
||||
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not verify ID token") {
|
||||
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
|
||||
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, errors.New("claims extraction failed")
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not extract claims") {
|
||||
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
// Claims without nonce
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in session") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
|
||||
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce mismatch") {
|
||||
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
|
||||
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
|
||||
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return false } // Disallow all domains
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
|
||||
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
saveError: errors.New("save failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to save session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
|
||||
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
setAuthError: errors.New("set auth failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to update session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
|
||||
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/dashboard",
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{
|
||||
IDToken: "valid-id-token",
|
||||
AccessToken: "valid-access-token",
|
||||
RefreshToken: "valid-refresh-token",
|
||||
}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if errorSent {
|
||||
t.Error("Unexpected error response sent")
|
||||
}
|
||||
|
||||
// Check redirect
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
|
||||
}
|
||||
|
||||
// Verify session data was set correctly
|
||||
if session.email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
|
||||
}
|
||||
|
||||
if session.idToken != "valid-id-token" {
|
||||
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
|
||||
}
|
||||
|
||||
if session.accessToken != "valid-access-token" {
|
||||
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
|
||||
}
|
||||
|
||||
if session.refreshToken != "valid-refresh-token" {
|
||||
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
|
||||
}
|
||||
|
||||
if !session.authenticated {
|
||||
t.Error("Expected session to be authenticated")
|
||||
}
|
||||
|
||||
// Check that temporary fields are cleared
|
||||
if session.csrf != "" {
|
||||
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
|
||||
}
|
||||
|
||||
if session.nonce != "" {
|
||||
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.codeVerifier != "" {
|
||||
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
|
||||
}
|
||||
|
||||
if session.incomingPath != "" {
|
||||
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
|
||||
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "", // No incoming path, should default to "/"
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Check redirect to default path
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
|
||||
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/callback", // Same as redirect URL path
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Should redirect to default path when incoming path is same as callback path
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,454 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestURLHelper_NewURLHelper tests the URLHelper constructor
|
||||
func TestURLHelper_NewURLHelper(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
if helper == nil {
|
||||
t.Fatal("Expected URLHelper to be created, got nil")
|
||||
}
|
||||
|
||||
if helper.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
|
||||
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentURL string
|
||||
excludedURLs map[string]struct{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Exact match",
|
||||
currentURL: "/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Prefix match",
|
||||
currentURL: "/health/status",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - first match",
|
||||
currentURL: "/api/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - second match",
|
||||
currentURL: "/health/check",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded URLs",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Root path exclusion",
|
||||
currentURL: "/anything",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive matching",
|
||||
currentURL: "/API/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Partial substring but not prefix",
|
||||
currentURL: "/user/api/test",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty current URL",
|
||||
currentURL: "",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "URL with query parameters",
|
||||
currentURL: "/health?status=ok",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
|
||||
// Verify debug logging for excluded URLs
|
||||
if result && len(logger.debugMessages) > 0 {
|
||||
// Should have logged a debug message for excluded URL
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if msg == "URL is excluded - got %s / excluded hit: %s" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected debug message for excluded URL")
|
||||
}
|
||||
}
|
||||
|
||||
// Reset logger messages for next test
|
||||
logger.debugMessages = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineScheme tests scheme determination
|
||||
func TestURLHelper_DetermineScheme(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - https",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - http",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "TLS connection without X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "No TLS and no X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto takes precedence over TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Proto falls back to TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineScheme(req)
|
||||
if result != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineHost tests host determination
|
||||
func TestURLHelper_DetermineHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Host header present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "No X-Forwarded-Host, use req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "direct.example.com"
|
||||
return req
|
||||
},
|
||||
expectedHost: "direct.example.com",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Host falls back to req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "fallback.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "")
|
||||
return req
|
||||
},
|
||||
expectedHost: "fallback.example.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com:8080"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com:443",
|
||||
},
|
||||
{
|
||||
name: "req.Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
|
||||
req.Host = "example.com:8080"
|
||||
return req
|
||||
},
|
||||
expectedHost: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "Multiple X-Forwarded-Host values (first one used)",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "first.example.com, second.example.com", // Header value as-is
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineHost(req)
|
||||
if result != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
|
||||
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "Both headers present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, TLS connection",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "secure.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, no TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
|
||||
req.Host = "plain.example.com"
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "plain.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only scheme header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "mixed.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "mixed.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only host header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "external.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
|
||||
scheme := helper.DetermineScheme(req)
|
||||
host := helper.DetermineHost(req)
|
||||
|
||||
if scheme != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
|
||||
}
|
||||
|
||||
if host != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
|
||||
}
|
||||
|
||||
// Test that we can build a complete URL
|
||||
fullURL := scheme + "://" + host + "/callback"
|
||||
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
|
||||
if fullURL != expectedURL {
|
||||
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure the helper methods are performant
|
||||
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
excludedURLs := map[string]struct{}{
|
||||
"/health": {},
|
||||
"/metrics": {},
|
||||
"/status": {},
|
||||
"/api/v1": {},
|
||||
"/api/v2": {},
|
||||
"/static": {},
|
||||
"/assets": {},
|
||||
"/favicon": {},
|
||||
"/robots": {},
|
||||
"/sitemap": {},
|
||||
}
|
||||
|
||||
testURL := "/api/users"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineExcludedURL(testURL, excludedURLs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineScheme(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineHost(req)
|
||||
}
|
||||
}
|
||||
+18
-8
@@ -109,7 +109,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
// Use shared transport pool to prevent memory leaks
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
pooledClient := CreateTokenHTTPClient()
|
||||
client = &http.Client{
|
||||
Transport: pooledClient.Transport,
|
||||
@@ -124,7 +124,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
// Read tokenURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -135,13 +140,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on defer
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
limitReader := io.LimitReader(resp.Body, 1024*10)
|
||||
bodyBytes, _ := io.ReadAll(limitReader)
|
||||
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
@@ -232,7 +237,7 @@ func NewTokenCache() *TokenCache {
|
||||
// - expiration: The duration for which the cache entry should be valid
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves cached claims for a token.
|
||||
@@ -355,8 +360,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
// Read endSessionURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
endSessionURL := t.endSessionURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
|
||||
+17
-5
@@ -49,10 +49,10 @@ func DefaultHTTPClientConfig() HTTPClientConfig {
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
ResponseHeaderTimeout: 3 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 5 * time.Second,
|
||||
MaxIdleConns: 20, // SECURITY FIX: Reduced from 100 to limit resource usage
|
||||
MaxIdleConnsPerHost: 2, // SECURITY FIX: Reduced from 10 to prevent connection exhaustion
|
||||
MaxConnsPerHost: 5, // SECURITY FIX: Reduced from 10 to limit concurrent connections
|
||||
IdleConnTimeout: 30 * time.Second, // OPTIMIZATION: Increased for better connection reuse
|
||||
MaxIdleConns: 50, // OPTIMIZATION: Increased from 20 for better connection pooling
|
||||
MaxIdleConnsPerHost: 10, // OPTIMIZATION: Increased from 2 for better connection reuse
|
||||
MaxConnsPerHost: 20, // OPTIMIZATION: Increased from 5 while maintaining security
|
||||
WriteBufferSize: 4096,
|
||||
ReadBufferSize: 4096,
|
||||
ForceHTTP2: true,
|
||||
@@ -70,6 +70,18 @@ func TokenHTTPClientConfig() HTTPClientConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
// OIDCProviderHTTPClientConfig returns configuration optimized for OIDC provider calls
|
||||
func OIDCProviderHTTPClientConfig() HTTPClientConfig {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.Timeout = 15 * time.Second // Slightly longer for OIDC operations
|
||||
config.MaxIdleConns = 100 // Higher pool for frequent OIDC calls
|
||||
config.MaxIdleConnsPerHost = 25 // More connections per OIDC provider
|
||||
config.MaxConnsPerHost = 50 // Allow more concurrent requests to OIDC provider
|
||||
config.IdleConnTimeout = 90 * time.Second // Keep connections alive longer for reuse
|
||||
config.UseCookieJar = true // Enable cookie jar for session management
|
||||
return config
|
||||
}
|
||||
|
||||
// HTTPClientFactory provides methods for creating configured HTTP clients
|
||||
type HTTPClientFactory struct{}
|
||||
|
||||
@@ -233,7 +245,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
|
||||
|
||||
// Add cookie jar if requested
|
||||
if config.UseCookieJar {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
client.Jar = jar
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
|
||||
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
|
||||
config := OIDCProviderHTTPClientConfig()
|
||||
|
||||
// Verify OIDC-specific settings
|
||||
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
|
||||
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
|
||||
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
|
||||
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
|
||||
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
|
||||
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
|
||||
}
|
||||
|
||||
// TestCreateDefaultClientUnit tests CreateDefaultClient function
|
||||
func TestCreateDefaultClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateDefaultClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateTokenClientUnit tests CreateTokenClient function
|
||||
func TestCreateTokenClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateTokenClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.NotNil(t, client.Jar, "token client should have cookie jar")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
|
||||
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxIdleConns: 20,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
UseCookieJar: true,
|
||||
}
|
||||
|
||||
client := CreateHTTPClientWithConfig(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 5*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
|
||||
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
t.Run("zero values get defaults", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
// All zero values
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Verify defaults were applied
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("custom values preserved", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 15 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxRedirects: 3,
|
||||
UseCookieJar: true,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 15*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar)
|
||||
})
|
||||
|
||||
t.Run("invalid timeout gets default", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: -1 * time.Second, // Invalid
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Should get default due to validation failure
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
|
||||
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config HTTPClientConfig
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 20,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConns",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConns too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 1500,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns too high",
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "timeout too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Minute,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout too high",
|
||||
},
|
||||
{
|
||||
name: "negative timeout",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: -1 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout must be positive",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: 50,
|
||||
MaxConnsPerHost: 10,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := factory.ValidateHTTPClientConfig(&tt.config)
|
||||
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+42
-10
@@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
|
||||
}
|
||||
|
||||
// cleanupIdleTransports periodically cleans up unused transports
|
||||
// Uses two-phase cleanup to minimize lock contention:
|
||||
// 1. Find candidates while holding read lock
|
||||
// 2. Remove and close transports with minimal lock duration
|
||||
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
@@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(p.transports, transportKey)
|
||||
// SECURITY FIX: Decrement client count when removing transport
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
}
|
||||
p.performCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performCleanup does the actual cleanup with optimized locking
|
||||
func (p *SharedTransportPool) performCleanup() {
|
||||
now := time.Now()
|
||||
|
||||
// Phase 1: Find candidates while holding read lock (fast)
|
||||
p.mu.RLock()
|
||||
candidates := make([]string, 0)
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
candidates = append(candidates, transportKey)
|
||||
}
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 2: Remove and close each candidate individually
|
||||
// This minimizes lock contention and allows concurrent access
|
||||
for _, key := range candidates {
|
||||
p.mu.Lock()
|
||||
shared, exists := p.transports[key]
|
||||
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
// Remove from map first (releases memory)
|
||||
delete(p.transports, key)
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
p.mu.Unlock()
|
||||
|
||||
// Close idle connections outside the lock (can be slow)
|
||||
if shared.transport != nil {
|
||||
shared.transport.CloseIdleConnections()
|
||||
}
|
||||
} else {
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
|
||||
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
|
||||
t.Run("create new transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
require.NotNil(t, transport)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
|
||||
assert.Len(t, pool.transports, 1)
|
||||
})
|
||||
|
||||
t.Run("reuse existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config)
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Equal(t, transport1, transport2, "should reuse same transport")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
|
||||
|
||||
// Check ref count
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
|
||||
})
|
||||
|
||||
t.Run("client limit enforcement", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 5, // Already at max
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Nil(t, transport, "should return nil when at client limit")
|
||||
})
|
||||
|
||||
t.Run("client limit with existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create first transport
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config1)
|
||||
require.NotNil(t, transport1)
|
||||
|
||||
// Set client count to max
|
||||
atomic.StoreInt32(&pool.clientCount, 5)
|
||||
|
||||
// Try to create with different config
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15 // Different config
|
||||
transport2 := pool.GetOrCreateTransport(config2)
|
||||
|
||||
// Should return existing transport since at limit
|
||||
assert.NotNil(t, transport2)
|
||||
assert.Equal(t, transport1, transport2)
|
||||
})
|
||||
|
||||
t.Run("updates last used time", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
firstTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Get again
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport2)
|
||||
|
||||
pool.mu.RLock()
|
||||
secondTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolReleaseTransport tests transport release
|
||||
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
|
||||
t.Run("decrement ref count", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Get again to increase ref count
|
||||
pool.GetOrCreateTransport(config)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 2, refCount)
|
||||
|
||||
// Release
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
newRefCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 1, newRefCount)
|
||||
})
|
||||
|
||||
t.Run("ref count reaches zero", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
pool.mu.RUnlock()
|
||||
|
||||
// Release to zero
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 0, shared.refCount)
|
||||
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
|
||||
})
|
||||
|
||||
t.Run("release non-existent transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create a transport not in the pool
|
||||
fakeTransport := &http.Transport{}
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
pool.ReleaseTransport(fakeTransport)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("release updates last used", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
beforeRelease := time.Now()
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
lastUsed := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanup tests cleanup functionality
|
||||
func TestSharedTransportPoolCleanup(t *testing.T) {
|
||||
t.Run("cleanup all transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create multiple transports
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
pool.GetOrCreateTransport(config1)
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15
|
||||
pool.GetOrCreateTransport(config2)
|
||||
|
||||
assert.Greater(t, len(pool.transports), 0)
|
||||
|
||||
// Cleanup
|
||||
pool.Cleanup()
|
||||
|
||||
assert.Len(t, pool.transports, 0, "all transports should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup cancels context", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
pool.Cleanup()
|
||||
|
||||
select {
|
||||
case <-pool.ctx.Done():
|
||||
// Context was canceled
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("context should be canceled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cleanup with no transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
pool.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("cleanup closes idle connections", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Cleanup should call CloseIdleConnections on each transport
|
||||
pool.Cleanup()
|
||||
|
||||
// Verify transports map is cleared
|
||||
assert.Empty(t, pool.transports)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
|
||||
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cleanup goroutine test in short mode")
|
||||
}
|
||||
|
||||
t.Run("cleanup removes idle transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport and release it
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
// Set lastUsed to old time
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Start cleanup in background (simulating what would happen)
|
||||
// Note: We're testing the cleanup logic manually here
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should be removed
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.False(t, exists, "old idle transport should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup preserves active transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport with refs
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Keep ref count > 0, but set old lastUsed
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup logic
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should still exist (has ref count)
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, exists, "transport with references should be preserved")
|
||||
})
|
||||
|
||||
t.Run("cleanup respects context cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
pool.cleanupIdleTransports(ctx)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Should exit quickly
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("cleanup goroutine should exit on context cancellation")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCreatePooledHTTPClient tests pooled client creation
|
||||
func TestCreatePooledHTTPClient(t *testing.T) {
|
||||
t.Run("create client with default config", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport)
|
||||
assert.Equal(t, config.Timeout, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("create multiple clients reuse transport", func(t *testing.T) {
|
||||
// Reset global pool for clean test
|
||||
globalTransportPoolOnce = sync.Once{}
|
||||
globalTransportPool = nil
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
client1 := CreatePooledHTTPClient(config)
|
||||
client2 := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client1)
|
||||
require.NotNil(t, client2)
|
||||
|
||||
// Should use same transport
|
||||
assert.Equal(t, client1.Transport, client2.Transport)
|
||||
})
|
||||
|
||||
t.Run("redirect policy is set", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 3
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.CheckRedirect)
|
||||
|
||||
// Test redirect limit
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 3; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after max redirects")
|
||||
})
|
||||
|
||||
t.Run("default redirect limit", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 0 // Should default to 10
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Test default redirect limit (10)
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 10; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after 10 redirects")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetGlobalTransportPool tests singleton pattern
|
||||
func TestGetGlobalTransportPool(t *testing.T) {
|
||||
t.Run("returns same instance", func(t *testing.T) {
|
||||
pool1 := GetGlobalTransportPool()
|
||||
pool2 := GetGlobalTransportPool()
|
||||
|
||||
assert.Equal(t, pool1, pool2, "should return same singleton instance")
|
||||
})
|
||||
|
||||
t.Run("pool is initialized", func(t *testing.T) {
|
||||
pool := GetGlobalTransportPool()
|
||||
|
||||
require.NotNil(t, pool)
|
||||
assert.NotNil(t, pool.transports)
|
||||
assert.Equal(t, 20, pool.maxConns)
|
||||
assert.Equal(t, int32(5), pool.maxClients)
|
||||
assert.NotNil(t, pool.ctx)
|
||||
assert.NotNil(t, pool.cancel)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolConcurrency tests thread safety
|
||||
func TestSharedTransportPoolConcurrency(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrency test in short mode")
|
||||
}
|
||||
|
||||
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10, // Allow more for concurrency test
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
const numGoroutines = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
transports := make([]*http.Transport, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
transports[idx] = pool.GetOrCreateTransport(config)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All should get same transport
|
||||
firstTransport := transports[0]
|
||||
for i := 1; i < numGoroutines; i++ {
|
||||
if transports[i] != nil {
|
||||
assert.Equal(t, firstTransport, transports[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
// Increase ref count
|
||||
for i := 0; i < 20; i++ {
|
||||
pool.GetOrCreateTransport(config)
|
||||
}
|
||||
|
||||
const numReleases = 20
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numReleases; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.ReleaseTransport(transport)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic and ref count should be decremented
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolEdgeCases tests edge cases
|
||||
func TestSharedTransportPoolEdgeCases(t *testing.T) {
|
||||
t.Run("config key generation", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
config1.MaxIdleConnsPerHost = 5
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 10
|
||||
config2.MaxIdleConnsPerHost = 5
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.Equal(t, key1, key2, "same config should produce same key")
|
||||
})
|
||||
|
||||
t.Run("different configs produce different keys", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 20
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
|
||||
})
|
||||
|
||||
t.Run("client count decrements on cleanup", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
initialCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(1), initialCount)
|
||||
|
||||
// Release and mark as old
|
||||
pool.ReleaseTransport(transport)
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
finalCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
|
||||
})
|
||||
}
|
||||
Vendored
+90
@@ -0,0 +1,90 @@
|
||||
package backends
|
||||
|
||||
import "time"
|
||||
|
||||
// BackendType represents the type of cache backend
|
||||
type BackendType string
|
||||
|
||||
const (
|
||||
BackendTypeMemory BackendType = "memory"
|
||||
BackendTypeRedis BackendType = "redis"
|
||||
BackendTypeHybrid BackendType = "hybrid"
|
||||
|
||||
// Aliases for backward compatibility
|
||||
TypeMemory BackendType = "memory"
|
||||
TypeRedis BackendType = "redis"
|
||||
TypeHybrid BackendType = "hybrid"
|
||||
)
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns a default configuration for Redis caching
|
||||
func DefaultRedisConfig(addr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeRedis,
|
||||
RedisAddr: addr,
|
||||
RedisDB: 0,
|
||||
RedisPrefix: "traefikoidc:",
|
||||
PoolSize: 10,
|
||||
EnableCircuitBreaker: true,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHybridConfig returns a default configuration for hybrid caching
|
||||
func DefaultHybridConfig(redisAddr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeHybrid,
|
||||
L1Config: &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 500,
|
||||
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
},
|
||||
L2Config: DefaultRedisConfig(redisAddr),
|
||||
AsyncWrites: true,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
+59
@@ -0,0 +1,59 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDefaultHybridConfig verifies the default hybrid configuration
|
||||
func TestDefaultHybridConfig(t *testing.T) {
|
||||
redisAddr := "localhost:6379"
|
||||
|
||||
config := DefaultHybridConfig(redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Verify top-level config
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.True(t, config.AsyncWrites)
|
||||
assert.True(t, config.EnableMetrics)
|
||||
|
||||
// Verify L1 (memory) config
|
||||
require.NotNil(t, config.L1Config)
|
||||
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
|
||||
assert.Equal(t, 500, config.L1Config.MaxSize)
|
||||
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
|
||||
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
|
||||
|
||||
// Verify L2 (Redis) config exists
|
||||
require.NotNil(t, config.L2Config)
|
||||
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
|
||||
}
|
||||
|
||||
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
redisAddr string
|
||||
}{
|
||||
{"localhost", "localhost:6379"},
|
||||
{"remote host", "redis.example.com:6379"},
|
||||
{"IP address", "192.168.1.100:6379"},
|
||||
{"custom port", "localhost:6380"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := DefaultHybridConfig(tt.redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.NotNil(t, config.L1Config)
|
||||
assert.NotNil(t, config.L2Config)
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
+38
@@ -0,0 +1,38 @@
|
||||
package backends
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrBackendClosed is returned when operating on a closed backend
|
||||
ErrBackendClosed = errors.New("cache backend is closed")
|
||||
|
||||
// ErrKeyNotFound is returned when a key doesn't exist
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
|
||||
// ErrCacheMiss indicates the requested key was not found in the cache
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// ErrBackendUnavailable indicates the cache backend is not available
|
||||
ErrBackendUnavailable = errors.New("cache backend unavailable")
|
||||
|
||||
// ErrInvalidValue indicates the cached value is invalid or corrupted
|
||||
ErrInvalidValue = errors.New("invalid cached value")
|
||||
|
||||
// ErrInvalidTTL is returned when TTL is invalid
|
||||
ErrInvalidTTL = errors.New("invalid TTL")
|
||||
|
||||
// ErrConnectionFailed is returned when connection fails
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrCircuitOpen is returned when circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTimeout is returned when operation times out
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
|
||||
// ErrSerializationFailed is returned when serialization fails
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
|
||||
// ErrDeserializationFailed is returned when deserialization fails
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
)
|
||||
Vendored
+695
@@ -0,0 +1,695 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// defaultLogger provides a basic logger implementation
|
||||
type defaultLogger struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
l.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
|
||||
l.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// HybridConfig provides configuration for the hybrid backend
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.Primary == nil {
|
||||
return nil, fmt.Errorf("primary (L1) backend is required")
|
||||
}
|
||||
|
||||
if config.Secondary == nil {
|
||||
return nil, fmt.Errorf("secondary (L2) backend is required")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
|
||||
}
|
||||
|
||||
if config.AsyncBufferSize <= 0 {
|
||||
config.AsyncBufferSize = 1000
|
||||
}
|
||||
|
||||
// Default critical cache types that require synchronous writes
|
||||
if config.SyncWriteCacheTypes == nil {
|
||||
config.SyncWriteCacheTypes = map[string]bool{
|
||||
"blacklist": true, // Token blacklist must be immediately consistent
|
||||
"token": true, // Token validation is critical
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &HybridBackend{
|
||||
primary: config.Primary,
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
// Start async write worker
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
|
||||
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
|
||||
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
|
||||
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Set stores a value in both L1 and L2 caches
|
||||
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Always write to L1 first (synchronous)
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L1 cache: %v", err)
|
||||
// Continue to try L2 even if L1 fails
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
// Determine if this should be a sync or async write based on cache type
|
||||
cacheType := h.extractCacheType(key)
|
||||
requiresSync := h.syncWriteCacheTypes[cacheType]
|
||||
|
||||
if requiresSync {
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from cache, checking L1 first, then L2
|
||||
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
// Try L1 first
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
h.l1Hits.Add(1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
// Try L2
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
}
|
||||
|
||||
if !exists {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from both L1 and L2 caches
|
||||
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
var deleted bool
|
||||
|
||||
// Delete from L1
|
||||
if d, err := h.primary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
|
||||
// Delete from L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if d, err := h.secondary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in either cache
|
||||
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// Check L1 first
|
||||
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys from both caches
|
||||
func (h *HybridBackend) Clear(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
// Clear L1
|
||||
if err := h.primary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L1 cache: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
// Clear L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if err := h.secondary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetStats returns statistics for the hybrid cache
|
||||
func (h *HybridBackend) GetStats() map[string]interface{} {
|
||||
l1Hits := h.l1Hits.Load()
|
||||
l2Hits := h.l2Hits.Load()
|
||||
misses := h.misses.Load()
|
||||
total := l1Hits + l2Hits + misses
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": TypeHybrid,
|
||||
"l1_hits": l1Hits,
|
||||
"l2_hits": l2Hits,
|
||||
"misses": misses,
|
||||
"total": total,
|
||||
"l1_writes": h.l1Writes.Load(),
|
||||
"l2_writes": h.l2Writes.Load(),
|
||||
"errors": h.errors.Load(),
|
||||
"fallback_mode": h.fallbackMode.Load(),
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
|
||||
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
|
||||
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
|
||||
}
|
||||
|
||||
// Add sub-backend stats
|
||||
stats["l1_stats"] = h.primary.GetStats()
|
||||
stats["l2_stats"] = h.secondary.GetStats()
|
||||
|
||||
// Add last L2 error time if available
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok {
|
||||
stats["last_l2_error"] = t.Format(time.RFC3339)
|
||||
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks if both backends are healthy
|
||||
func (h *HybridBackend) Ping(ctx context.Context) error {
|
||||
// Check L1
|
||||
if err := h.primary.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("L1 ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check L2 (but don't fail if it's down)
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
h.logger.Warnf("L2 ping failed: %v", err)
|
||||
h.recordL2Error()
|
||||
// Don't return error - we can operate with L1 only
|
||||
} else {
|
||||
// L2 is healthy, clear fallback mode if it was set
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend recovered, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the hybrid backend
|
||||
func (h *HybridBackend) Close() error {
|
||||
// Cancel context to stop workers
|
||||
h.cancel()
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Workers finished
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warnf("Timeout waiting for workers to finish")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Close backends
|
||||
if err := h.primary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L1 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if err := h.secondary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L2 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
h.logger.Infof("HybridBackend closed")
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values efficiently
|
||||
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
results := make(map[string][]byte, len(keys))
|
||||
missingKeys := make([]string, 0)
|
||||
|
||||
// Try L1 first for all keys
|
||||
for _, key := range keys {
|
||||
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
|
||||
results[key] = value
|
||||
h.l1Hits.Add(1)
|
||||
} else {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// If all found in L1 or in fallback mode, return
|
||||
if len(missingKeys) == 0 || h.fallbackMode.Load() {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Try L2 for missing keys using batch operation if available
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
GetMany(context.Context, []string) (map[string][]byte, error)
|
||||
}); ok {
|
||||
l2Results, err := batcher.GetMany(ctx, missingKeys)
|
||||
if err != nil {
|
||||
h.logger.Debugf("L2 batch get error: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual gets
|
||||
for _, key := range missingKeys {
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count misses for keys not found anywhere
|
||||
for _, key := range keys {
|
||||
if _, found := results[key]; !found {
|
||||
h.misses.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetMany stores multiple key-value pairs efficiently
|
||||
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write to L1 first
|
||||
for key, value := range items {
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip L2 if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if L2 supports batch operations
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
SetMany(context.Context, map[string][]byte, time.Duration) error
|
||||
}); ok {
|
||||
if err := batcher.SetMany(ctx, items, ttl); err != nil {
|
||||
h.logger.Warnf("Failed to batch write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(int64(len(items)))
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual sets
|
||||
for key, value := range items {
|
||||
cacheType := h.extractCacheType(key)
|
||||
if h.syncWriteCacheTypes[cacheType] {
|
||||
// Sync write for critical types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
}
|
||||
} else {
|
||||
// Async write for non-critical types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
// Queued
|
||||
default:
|
||||
h.logger.Warnf("Async buffer full for batch write")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items with best effort
|
||||
for len(h.asyncWriteBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.asyncWriteBuffer:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.asyncWriteBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the write with a timeout
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthMonitor periodically checks L2 health and manages fallback mode
|
||||
func (h *HybridBackend) healthMonitor() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
if !h.fallbackMode.Load() {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
|
||||
}
|
||||
} else {
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend healthy, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordL2Error records the timestamp of an L2 error
|
||||
func (h *HybridBackend) recordL2Error() {
|
||||
h.lastL2Error.Store(time.Now())
|
||||
|
||||
// Check if we should enter fallback mode based on recent errors
|
||||
if !h.fallbackMode.Load() {
|
||||
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractCacheType attempts to determine the cache type from the key
|
||||
func (h *HybridBackend) extractCacheType(key string) string {
|
||||
// Simple heuristic based on key prefixes
|
||||
// This should match the actual cache type strategy in the main application
|
||||
|
||||
if len(key) > 10 {
|
||||
prefix := key[:10]
|
||||
switch {
|
||||
case contains(prefix, "blacklist"):
|
||||
return "blacklist"
|
||||
case contains(prefix, "token"):
|
||||
return "token"
|
||||
case contains(prefix, "metadata"):
|
||||
return "metadata"
|
||||
case contains(prefix, "jwk"):
|
||||
return "jwk"
|
||||
case contains(prefix, "session"):
|
||||
return "session"
|
||||
case contains(prefix, "introspect"):
|
||||
return "introspection"
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if toLower(s[i+j]) != toLower(substr[j]) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower converts a byte to lowercase
|
||||
func toLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + 32
|
||||
}
|
||||
return b
|
||||
}
|
||||
+1490
File diff suppressed because it is too large
Load Diff
Vendored
+133
@@ -0,0 +1,133 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheBackend defines the interface for all cache backend implementations
|
||||
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
|
||||
type CacheBackend interface {
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
// Returns an error if the operation fails
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
// Returns: value, remaining TTL, exists flag, and error
|
||||
// If the key doesn't exist, exists will be false
|
||||
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
|
||||
|
||||
// Delete removes a key from the cache
|
||||
// Returns true if the key was deleted, false if it didn't exist
|
||||
Delete(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// GetStats returns cache statistics
|
||||
// Stats include: hits, misses, size, memory usage, etc.
|
||||
GetStats() map[string]interface{}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
Close() error
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
type BackendCapabilities struct {
|
||||
// Distributed indicates if the backend is distributed across multiple instances
|
||||
Distributed bool
|
||||
|
||||
// Persistent indicates if the backend persists data across restarts
|
||||
Persistent bool
|
||||
|
||||
// Eviction indicates if the backend supports automatic eviction
|
||||
Eviction bool
|
||||
|
||||
// TTL indicates if the backend supports TTL (time-to-live)
|
||||
TTL bool
|
||||
|
||||
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
|
||||
MaxKeySize int64
|
||||
|
||||
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
|
||||
MaxValueSize int64
|
||||
|
||||
// MaxKeys is the maximum number of keys (0 = unlimited)
|
||||
MaxKeys int64
|
||||
|
||||
// SupportsExpire indicates if the backend supports expiration
|
||||
SupportsExpire bool
|
||||
|
||||
// SupportsMultiGet indicates if the backend supports batch get operations
|
||||
SupportsMultiGet bool
|
||||
|
||||
// SupportsTransaction indicates if the backend supports transactions
|
||||
SupportsTransaction bool
|
||||
|
||||
// SupportsCompression indicates if the backend supports compression
|
||||
SupportsCompression bool
|
||||
|
||||
// RequiresSerialize indicates if values must be serialized
|
||||
RequiresSerialize bool
|
||||
|
||||
// AtomicOperations indicates if the backend supports atomic operations
|
||||
AtomicOperations bool
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
|
||||
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
|
||||
func TestCacheBackendContract(t *testing.T) {
|
||||
// Test suite will be run against each backend type
|
||||
t.Run("MemoryBackend", func(t *testing.T) {
|
||||
backend := setupMemoryBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("RedisBackend", func(t *testing.T) {
|
||||
backend := setupRedisBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("HybridBackend", func(t *testing.T) {
|
||||
backend := setupHybridBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// runContractTests executes all contract tests against a backend
|
||||
func runContractTests(t *testing.T, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BasicSetGet", func(t *testing.T) {
|
||||
testBasicSetGet(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
testGetNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
testUpdateExisting(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testDelete(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
testDeleteNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
testExists(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("TTLExpiration", func(t *testing.T) {
|
||||
testTTLExpiration(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
testClear(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
testPing(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
testStats(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
testConcurrentAccess(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("LargeValues", func(t *testing.T) {
|
||||
testLargeValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("EmptyValues", func(t *testing.T) {
|
||||
testEmptyValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
|
||||
testSpecialCharactersInKeys(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// testBasicSetGet verifies basic set and get operations
|
||||
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "test-key-1"
|
||||
value := []byte("test-value-1")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err, "Set should not return error")
|
||||
|
||||
// Get value
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error")
|
||||
assert.True(t, exists, "Key should exist")
|
||||
assert.Equal(t, value, retrieved, "Retrieved value should match")
|
||||
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
|
||||
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
|
||||
}
|
||||
|
||||
// testGetNonExistent verifies behavior when getting non-existent keys
|
||||
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-key"
|
||||
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error for non-existent key")
|
||||
assert.False(t, exists, "Key should not exist")
|
||||
assert.Nil(t, retrieved, "Value should be nil")
|
||||
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
|
||||
}
|
||||
|
||||
// testUpdateExisting verifies updating an existing key
|
||||
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set initial value
|
||||
err := backend.Set(ctx, key, value1, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update value
|
||||
err = backend.Set(ctx, key, value2, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated value
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved, "Value should be updated")
|
||||
}
|
||||
|
||||
// testDelete verifies delete operation
|
||||
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted, "Delete should return true for existing key")
|
||||
|
||||
// Verify deleted
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after delete")
|
||||
}
|
||||
|
||||
// testDeleteNonExistent verifies deleting non-existent keys
|
||||
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-delete-key"
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted, "Delete should return false for non-existent key")
|
||||
}
|
||||
|
||||
// testExists verifies the Exists operation
|
||||
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist initially")
|
||||
|
||||
// Set value
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist after Set")
|
||||
}
|
||||
|
||||
// testTTLExpiration verifies TTL expiration behavior
|
||||
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
// Set with short TTL
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist immediately after Set")
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after TTL expiration")
|
||||
}
|
||||
|
||||
// testClear verifies Clear operation
|
||||
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
// Set multiple keys
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Give async writes time to complete before clearing
|
||||
// This prevents race conditions with async write workers
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Clear all
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after Clear")
|
||||
}
|
||||
}
|
||||
|
||||
// testPing verifies Ping operation
|
||||
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
err := backend.Ping(ctx)
|
||||
assert.NoError(t, err, "Ping should succeed on healthy backend")
|
||||
}
|
||||
|
||||
// testStats verifies GetStats operation
|
||||
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
stats := backend.GetStats()
|
||||
assert.NotNil(t, stats, "Stats should not be nil")
|
||||
|
||||
// Stats should contain basic metrics
|
||||
_, hasHits := stats["hits"]
|
||||
_, hasMisses := stats["misses"]
|
||||
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
|
||||
}
|
||||
|
||||
// testConcurrentAccess verifies thread safety
|
||||
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 20
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// testLargeValues verifies handling of large values
|
||||
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "large-value-key"
|
||||
value := GenerateLargeValue(1024 * 1024) // 1MB
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle large values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
|
||||
}
|
||||
|
||||
// testEmptyValues verifies handling of empty values
|
||||
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "empty-value-key"
|
||||
value := []byte{}
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle empty values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Empty value should exist")
|
||||
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
|
||||
}
|
||||
|
||||
// testSpecialCharactersInKeys verifies handling of special characters in keys
|
||||
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
specialKeys := []string{
|
||||
"key:with:colons",
|
||||
"key/with/slashes",
|
||||
"key-with-dashes",
|
||||
"key_with_underscores",
|
||||
"key.with.dots",
|
||||
"key|with|pipes",
|
||||
}
|
||||
|
||||
for _, key := range specialKeys {
|
||||
value := []byte(fmt.Sprintf("value-for-%s", key))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle special character in key: %s", key)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key with special characters should exist: %s", key)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to setup different backend types
|
||||
// These will be implemented in respective test files
|
||||
|
||||
func setupMemoryBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in memory_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("MemoryBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRedisBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in redis_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("RedisBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHybridBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
config := &HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 100,
|
||||
Logger: NewTestLogger(t),
|
||||
}
|
||||
|
||||
hybrid, err := NewHybridBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
hybrid.Close()
|
||||
})
|
||||
|
||||
return hybrid
|
||||
}
|
||||
Vendored
+516
@@ -0,0 +1,516 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
func (item *memoryCacheItem) isExpired() bool {
|
||||
if item.expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Statistics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
m.cleanupTicker = time.NewTicker(cleanupInterval)
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupExpired()
|
||||
case <-m.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalGetTime.Add(duration)
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalSetTime.Add(duration)
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching the pattern (use "*" for all keys)
|
||||
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.currentSize, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the cache backend
|
||||
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
m.mu.RUnlock()
|
||||
|
||||
avgGetLatency := time.Duration(0)
|
||||
if getCount := m.getCount.Load(); getCount > 0 {
|
||||
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
|
||||
}
|
||||
|
||||
avgSetLatency := time.Duration(0)
|
||||
if setCount := m.setCount.Load(); setCount > 0 {
|
||||
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
|
||||
}
|
||||
|
||||
return &BackendStats{
|
||||
Type: TypeMemory,
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Sets: m.sets.Load(),
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
LastErrorTime: lastErrorTime,
|
||||
Uptime: time.Since(m.startTime),
|
||||
StartTime: m.startTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the backend and releases resources
|
||||
func (m *MemoryCacheBackend) Close() error {
|
||||
if m.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (m *MemoryCacheBackend) IsHealthy() bool {
|
||||
return !m.closed.Load()
|
||||
}
|
||||
|
||||
// Type returns the backend type
|
||||
func (m *MemoryCacheBackend) Type() BackendType {
|
||||
return TypeMemory
|
||||
}
|
||||
|
||||
// Capabilities returns the backend capabilities
|
||||
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
return &BackendCapabilities{
|
||||
Distributed: false,
|
||||
Persistent: false,
|
||||
Eviction: true,
|
||||
TTL: true,
|
||||
MaxKeySize: 1024, // 1KB
|
||||
MaxValueSize: 10485760, // 10MB
|
||||
MaxKeys: m.maxSize,
|
||||
SupportsExpire: true,
|
||||
SupportsMultiGet: true,
|
||||
SupportsTransaction: false,
|
||||
SupportsCompression: false,
|
||||
RequiresSerialize: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
case []byte:
|
||||
return int64(len(v))
|
||||
case int, int32, int64, uint, uint32, uint64:
|
||||
return 8
|
||||
case float32, float64:
|
||||
return 8
|
||||
case bool:
|
||||
return 1
|
||||
default:
|
||||
// For complex types, use a default estimate
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
// matchPattern checks if a key matches a pattern (simplified glob matching)
|
||||
func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
)
|
||||
|
||||
// setupBenchmarkRedis creates a miniredis instance for benchmarking
|
||||
func setupBenchmarkRedis(b *testing.B) string {
|
||||
b.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
return mr.Addr()
|
||||
}
|
||||
|
||||
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
|
||||
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform various operations
|
||||
_, _ = conn.Do("SET", "bench-key", "bench-value")
|
||||
_, _ = conn.Do("GET", "bench-key")
|
||||
_, _ = conn.Do("EXISTS", "bench-key")
|
||||
_, _ = conn.Do("DEL", "bench-key")
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
|
||||
func BenchmarkRedisBackend_SetGet(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("benchmark test data with some content")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Set operation
|
||||
err := backend.Set(ctx, "bench-key", testData, 0)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get operation
|
||||
_, _, _, err = backend.Get(ctx, "bench-key")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
|
||||
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("concurrent benchmark data")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = backend.Set(ctx, "concurrent-key", testData, 0)
|
||||
_, _, _, _ = backend.Get(ctx, "concurrent-key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
|
||||
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// This tests the pooling of RESPReader/RESPWriter
|
||||
_, _ = conn.Do("PING")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
|
||||
func BenchmarkConnectionPool_GetPut(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
+783
@@ -0,0 +1,783 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMemoryBackend_BasicOperations tests basic CRUD operations
|
||||
func TestMemoryBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
assert.LessOrEqual(t, remainingTTL, ttl)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
deleted, err := backend.Delete(ctx, "non-existent-delete")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
// Add multiple items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(0), size)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTLExpiration tests TTL and expiration
|
||||
func TestMemoryBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "short-ttl-key"
|
||||
value := []byte("short-ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLDecrement", func(t *testing.T) {
|
||||
key := "ttl-decrement-key"
|
||||
value := []byte("ttl-decrement-value")
|
||||
ttl := 2 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check TTL again - should be less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
|
||||
})
|
||||
|
||||
t.Run("CleanupExpiredItems", func(t *testing.T) {
|
||||
// Set multiple items with short TTL
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 50*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// All items should be cleaned up
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expired items should be cleaned up")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LRUEviction tests LRU eviction
|
||||
func TestMemoryBackend_LRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 5
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill cache to max size
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("lru-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("lru-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Access first item to make it most recently used
|
||||
_, _, exists, err := backend.Get(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Add a new item - should evict lru-key-1 (least recently used)
|
||||
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// lru-key-0 should still exist (was accessed recently)
|
||||
exists, err = backend.Exists(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item should not be evicted")
|
||||
|
||||
// lru-key-1 should be evicted
|
||||
exists, err = backend.Exists(ctx, "lru-key-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Least recently used item should be evicted")
|
||||
|
||||
// Check eviction count
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_MemoryLimit tests memory-based eviction
|
||||
func TestMemoryBackend_MemoryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100
|
||||
config.MaxMemoryBytes = 1024 // 1KB limit
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items until memory limit is reached
|
||||
largeValue := make([]byte, 512) // 512 bytes each
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("mem-key-%d", i)
|
||||
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
memory := stats["memory"].(int64)
|
||||
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
|
||||
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ConcurrentAccess tests thread safety
|
||||
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// Random deletes
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_UpdateExisting tests updating existing keys
|
||||
func TestMemoryBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
|
||||
|
||||
// Size should not increase (same key)
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Stats tests statistics tracking
|
||||
func TestMemoryBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add items and track hits/misses
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Hit
|
||||
backend.Get(ctx, "key1")
|
||||
// Miss
|
||||
backend.Get(ctx, "non-existent")
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_EmptyValues tests handling of empty values
|
||||
func TestMemoryBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LargeValues tests handling of large values
|
||||
func TestMemoryBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Close tests proper cleanup on close
|
||||
func TestMemoryBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
_, _, _, err = backend.Get(ctx, "close-key-0")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Closing again should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Ping tests ping operation
|
||||
func TestMemoryBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
|
||||
func TestMemoryBackend_ValueIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "isolation-key"
|
||||
originalValue := []byte("original-value")
|
||||
|
||||
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value and modify it
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Modify retrieved value
|
||||
if len(retrieved) > 0 {
|
||||
retrieved[0] = 'X'
|
||||
}
|
||||
|
||||
// Get again - should be unchanged
|
||||
retrieved2, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Keys tests the Keys method with pattern matching
|
||||
func TestMemoryBackend_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add test data
|
||||
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
|
||||
for _, key := range testKeys {
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("AllKeys", func(t *testing.T) {
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 5)
|
||||
})
|
||||
|
||||
t.Run("SpecificPattern", func(t *testing.T) {
|
||||
// Simple exact match
|
||||
keys, err := backend.Keys(ctx, "user:1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Contains(t, keys, "user:1")
|
||||
})
|
||||
|
||||
t.Run("ExcludesExpired", func(t *testing.T) {
|
||||
// Add an expired key
|
||||
expiredKey := "expired:key"
|
||||
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.Keys(ctx, "*")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Size tests the Size method
|
||||
func TestMemoryBackend_Size(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially empty
|
||||
size, err := backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), size)
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), size)
|
||||
|
||||
// Delete one
|
||||
backend.Delete(ctx, "key-0")
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), size)
|
||||
|
||||
// After close
|
||||
backend.Close()
|
||||
_, err = backend.Size(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTL tests the TTL method
|
||||
func TestMemoryBackend_TTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, []byte("value"), ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, remaining, 50*time.Second)
|
||||
assert.LessOrEqual(t, remaining, ttl)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
_, err := backend.TTL(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("NoExpiration", func(t *testing.T) {
|
||||
key := "no-expiry"
|
||||
// TTL of 0 typically means no expiration
|
||||
err := backend.Set(ctx, key, []byte("value"), 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
// No expiration returns 0
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.TTL(ctx, "key")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Expire tests the Expire method
|
||||
func TestMemoryBackend_Expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateTTL", func(t *testing.T) {
|
||||
key := "expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update to shorter TTL
|
||||
err = backend.Expire(ctx, key, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check new TTL
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, remaining, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("RemoveExpiration", func(t *testing.T) {
|
||||
key := "no-expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set TTL to 0 to remove expiration
|
||||
err = backend.Expire(ctx, key, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TTL should now be 0
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_IsHealthy tests the IsHealthy method
|
||||
func TestMemoryBackend_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be healthy when open
|
||||
assert.True(t, backend.IsHealthy())
|
||||
|
||||
// Should be unhealthy after close
|
||||
backend.Close()
|
||||
assert.False(t, backend.IsHealthy())
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Type tests the Type method
|
||||
func TestMemoryBackend_Type(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
backendType := backend.Type()
|
||||
assert.Equal(t, TypeMemory, backendType)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Capabilities tests the Capabilities method
|
||||
func TestMemoryBackend_Capabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
caps := backend.Capabilities()
|
||||
require.NotNil(t, caps)
|
||||
|
||||
// Memory backend should not be distributed or persistent
|
||||
assert.False(t, caps.Distributed)
|
||||
assert.False(t, caps.Persistent)
|
||||
|
||||
// Should support eviction and TTL
|
||||
assert.True(t, caps.Eviction)
|
||||
assert.True(t, caps.TTL)
|
||||
assert.True(t, caps.SupportsExpire)
|
||||
assert.True(t, caps.SupportsMultiGet)
|
||||
|
||||
// Check limits
|
||||
assert.Greater(t, caps.MaxKeySize, int64(0))
|
||||
assert.Greater(t, caps.MaxValueSize, int64(0))
|
||||
}
|
||||
|
||||
// TestMatchPattern tests the matchPattern helper function
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
pattern string
|
||||
key string
|
||||
matches bool
|
||||
}{
|
||||
{"*", "any-key", true},
|
||||
{"*", "another", true},
|
||||
{"user:1", "user:1", true},
|
||||
{"user:1", "user:2", false},
|
||||
{"*:suffix", "prefix:suffix", true},
|
||||
{"*suffix", "prefix-suffix", true},
|
||||
{"*abc", "xyzabc", true},
|
||||
{"*abc", "xyz", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.key)
|
||||
assert.Equal(t, tt.matches, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
+153
@@ -0,0 +1,153 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
|
||||
type MemoryBackend struct {
|
||||
*MemoryCacheBackend
|
||||
}
|
||||
|
||||
// NewMemoryBackend creates a new memory backend from a config
|
||||
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
|
||||
maxSize := int64(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1000
|
||||
}
|
||||
|
||||
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
|
||||
return &MemoryBackend{
|
||||
MemoryCacheBackend: cacheBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
|
||||
if err == ErrBackendUnavailable {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
val, err := m.MemoryCacheBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
if err == ErrCacheMiss {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err == ErrBackendUnavailable {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
var valueBytes []byte
|
||||
if val != nil {
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
|
||||
return valueBytes, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
// Check if key exists first
|
||||
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = m.MemoryCacheBackend.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return m.MemoryCacheBackend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
func (m *MemoryBackend) Clear(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BackendStats to map
|
||||
hitRate := float64(0)
|
||||
total := stats.Hits + stats.Misses
|
||||
if total > 0 {
|
||||
hitRate = float64(stats.Hits) / float64(total)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
func (m *MemoryBackend) Close() error {
|
||||
return m.MemoryCacheBackend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
func (m *MemoryBackend) Ping(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Ping(ctx)
|
||||
}
|
||||
|
||||
// Ensure MemoryBackend implements CacheBackend
|
||||
var _ CacheBackend = (*MemoryBackend)(nil)
|
||||
Vendored
+455
@@ -0,0 +1,455 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pure-Go Redis client implementation
|
||||
// Compatible with Yaegi interpreter (no unsafe package)
|
||||
// Implements RESP protocol for basic Redis operations
|
||||
|
||||
var (
|
||||
ErrPoolExhausted = errors.New("connection pool exhausted")
|
||||
)
|
||||
|
||||
// RedisBackend implements a Redis-based cache backend using pure Go
|
||||
type RedisBackend struct {
|
||||
config *Config
|
||||
pool *ConnectionPool
|
||||
healthMonitor *HealthMonitor
|
||||
|
||||
// Metrics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
|
||||
func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.RedisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is required")
|
||||
}
|
||||
|
||||
// Create connection pool with health checks enabled
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Create health monitor
|
||||
healthConfig := DefaultHealthMonitorConfig()
|
||||
healthMonitor := NewHealthMonitor(pool, healthConfig)
|
||||
|
||||
backend := &RedisBackend{
|
||||
config: config,
|
||||
pool: pool,
|
||||
healthMonitor: healthMonitor,
|
||||
}
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
// Start health monitoring
|
||||
healthMonitor.Start()
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Set stores a value in Redis with TTL
|
||||
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
// Execute with retry logic
|
||||
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
var err error
|
||||
|
||||
// Use PSETEX for millisecond precision, SETEX for second precision
|
||||
if ttl > 0 {
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs (millisecond precision)
|
||||
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs (second precision)
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
_, err = conn.Do("SET", prefixedKey, string(value))
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Get retrieves a value from Redis
|
||||
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
var resultValue []byte
|
||||
var resultTTL time.Duration
|
||||
var resultExists bool
|
||||
|
||||
// Execute with retry logic
|
||||
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
// Get value
|
||||
resp, err := conn.Do("GET", prefixedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
r.misses.Add(1)
|
||||
resultExists = false
|
||||
return nil // Not an error, key just doesn't exist
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get TTL
|
||||
ttlResp, err := conn.Do("TTL", prefixedKey)
|
||||
if err != nil {
|
||||
// If TTL fails, still return the value
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = 0
|
||||
resultExists = true
|
||||
return nil
|
||||
}
|
||||
|
||||
ttlSeconds, _ := RESPInt(ttlResp)
|
||||
var ttl time.Duration
|
||||
if ttlSeconds > 0 {
|
||||
ttl = time.Duration(ttlSeconds) * time.Second
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = ttl
|
||||
resultExists = true
|
||||
return nil
|
||||
})
|
||||
|
||||
return resultValue, resultTTL, resultExists, err
|
||||
}
|
||||
|
||||
// Delete removes a key from Redis
|
||||
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("DEL", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in Redis
|
||||
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("EXISTS", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys with the configured prefix
|
||||
func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
// Use FLUSHDB if no prefix (clear entire DB)
|
||||
if r.config.RedisPrefix == "" {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
}
|
||||
|
||||
// With prefix, we need to scan and delete keys
|
||||
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
|
||||
pattern := r.config.RedisPrefix + "*"
|
||||
resp, err := conn.Do("KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract keys from array response
|
||||
keys, ok := resp.([]interface{})
|
||||
if !ok || len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete each key
|
||||
for _, keyInterface := range keys {
|
||||
key, err := RESPString(keyInterface)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns backend statistics
|
||||
func (r *RedisBackend) GetStats() map[string]interface{} {
|
||||
hits := r.hits.Load()
|
||||
misses := r.misses.Load()
|
||||
total := hits + misses
|
||||
|
||||
hitRate := float64(0)
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"backend": "redis-pure-go",
|
||||
"address": r.config.RedisAddr,
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": hitRate,
|
||||
"pool": r.pool.Stats(),
|
||||
}
|
||||
|
||||
// Add health monitor stats if available
|
||||
if r.healthMonitor != nil {
|
||||
stats["health"] = r.healthMonitor.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks Redis connectivity
|
||||
func (r *RedisBackend) Ping(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
_, err = conn.Do("PING")
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the Redis backend and all connections
|
||||
func (r *RedisBackend) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Stop health monitor
|
||||
if r.healthMonitor != nil {
|
||||
r.healthMonitor.Stop()
|
||||
}
|
||||
|
||||
// Close connection pool
|
||||
if r.pool != nil {
|
||||
return r.pool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds the configured prefix to a key
|
||||
func (r *RedisBackend) prefixKey(key string) string {
|
||||
if r.config.RedisPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
// If we can't get a connection and this is the last attempt, fail
|
||||
if attempt == maxRetries-1 {
|
||||
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the operation
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// If successful, return
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If error is not retryable or last attempt, fail
|
||||
if attempt == maxRetries-1 || !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts", maxRetries)
|
||||
}
|
||||
|
||||
// isRetryableError determines if an error is worth retrying
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retry on connection errors, timeouts, etc.
|
||||
// Don't retry on application-level errors like wrong type
|
||||
errMsg := err.Error()
|
||||
retryablePatterns := []string{
|
||||
"connection",
|
||||
"timeout",
|
||||
"EOF",
|
||||
"broken pipe",
|
||||
"reset by peer",
|
||||
}
|
||||
|
||||
for _, pattern := range retryablePatterns {
|
||||
if contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
|
||||
// State
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
lastCheckTime atomic.Int64 // Unix timestamp
|
||||
|
||||
// Metrics
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
CheckInterval time.Duration // How often to check health
|
||||
Timeout time.Duration // Timeout for health check
|
||||
UnhealthyThreshold int // Consecutive failures before marking unhealthy
|
||||
OnHealthChange func(healthy bool)
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
|
||||
return &HealthMonitorConfig{
|
||||
CheckInterval: 5 * time.Second,
|
||||
Timeout: 3 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor
|
||||
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
|
||||
if config == nil {
|
||||
config = DefaultHealthMonitorConfig()
|
||||
}
|
||||
|
||||
hm := &HealthMonitor{
|
||||
pool: pool,
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
hm.healthy.Store(true) // Assume healthy initially
|
||||
return hm
|
||||
}
|
||||
|
||||
// Start begins health monitoring
|
||||
func (hm *HealthMonitor) Start() {
|
||||
if hm.running.Swap(true) {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
hm.wg.Add(1)
|
||||
go hm.monitorLoop()
|
||||
}
|
||||
|
||||
// Stop stops health monitoring
|
||||
func (hm *HealthMonitor) Stop() {
|
||||
if !hm.running.Swap(false) {
|
||||
return // Not running
|
||||
}
|
||||
|
||||
close(hm.stopChan)
|
||||
hm.wg.Wait()
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (hm *HealthMonitor) IsHealthy() bool {
|
||||
return hm.healthy.Load()
|
||||
}
|
||||
|
||||
// GetStats returns health monitor statistics
|
||||
func (hm *HealthMonitor) GetStats() map[string]interface{} {
|
||||
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
|
||||
|
||||
return map[string]interface{}{
|
||||
"healthy": hm.healthy.Load(),
|
||||
"consecutive_failures": hm.consecutiveFailures.Load(),
|
||||
"total_checks": hm.totalChecks.Load(),
|
||||
"total_failures": hm.totalFailures.Load(),
|
||||
"last_check": lastCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop runs the health check loop
|
||||
func (hm *HealthMonitor) monitorLoop() {
|
||||
defer hm.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(hm.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Perform initial check immediately
|
||||
hm.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hm.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck executes a health check
|
||||
func (hm *HealthMonitor) performHealthCheck() {
|
||||
hm.totalChecks.Add(1)
|
||||
hm.lastCheckTime.Store(time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to get a connection and ping Redis
|
||||
conn, err := hm.pool.Get(ctx)
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
defer hm.pool.Put(conn)
|
||||
|
||||
// Ping Redis
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Success!
|
||||
hm.recordSuccess()
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hm *HealthMonitor) recordSuccess() {
|
||||
wasHealthy := hm.healthy.Load()
|
||||
hm.consecutiveFailures.Store(0)
|
||||
hm.healthy.Store(true)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if !wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(true)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hm *HealthMonitor) recordFailure() {
|
||||
hm.totalFailures.Add(1)
|
||||
failures := hm.consecutiveFailures.Add(1)
|
||||
|
||||
wasHealthy := hm.healthy.Load()
|
||||
|
||||
// Mark unhealthy if threshold exceeded
|
||||
if failures >= int64(hm.config.UnhealthyThreshold) {
|
||||
hm.healthy.Store(false)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHealthMonitor_BasicOperation tests basic health monitoring
|
||||
func TestHealthMonitor_BasicOperation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create health monitor with fast check interval for testing
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
require.NotNil(t, hm)
|
||||
|
||||
// Initially should be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for a few checks
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Should still be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
require.NotNil(t, stats)
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Greater(t, stats["total_checks"].(int64), int64(0))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
|
||||
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var healthChangedCalled atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if !healthy {
|
||||
healthChangedCalled.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
|
||||
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
assert.False(t, stats["healthy"].(bool))
|
||||
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
|
||||
assert.Greater(t, stats["total_failures"].(int64), int64(0))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
|
||||
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var recoveryDetected atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if healthy {
|
||||
recoveryDetected.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Should detect server failure")
|
||||
|
||||
// Clear error to simulate recovery
|
||||
mr.ClearError()
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should be healthy again
|
||||
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
|
||||
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
|
||||
|
||||
// Consecutive failures should be reset
|
||||
stats := hm.GetStats()
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StartStop tests start/stop behavior
|
||||
func TestHealthMonitor_StartStop(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Starting again should be no-op
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Stop monitoring
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
|
||||
// Stopping again should be no-op
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
|
||||
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create multiple monitors
|
||||
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 150 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
})
|
||||
|
||||
// Start both
|
||||
hm1.Start()
|
||||
hm2.Start()
|
||||
|
||||
// Both should be healthy
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
assert.True(t, hm1.IsHealthy())
|
||||
assert.True(t, hm2.IsHealthy())
|
||||
|
||||
// Stop both
|
||||
hm1.Stop()
|
||||
hm2.Stop()
|
||||
|
||||
// Verify they stopped
|
||||
assert.False(t, hm1.running.Load())
|
||||
assert.False(t, hm2.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StatsAccuracy tests stats tracking
|
||||
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for some checks
|
||||
time.Sleep(550 * time.Millisecond)
|
||||
|
||||
stats := hm.GetStats()
|
||||
|
||||
// Should have performed multiple checks
|
||||
totalChecks := stats["total_checks"].(int64)
|
||||
assert.GreaterOrEqual(t, totalChecks, int64(4))
|
||||
|
||||
// All checks should succeed
|
||||
assert.Equal(t, int64(0), stats["total_failures"].(int64))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
|
||||
// Last check time should be recent (within check interval + buffer)
|
||||
// Use 2s tolerance to account for CI runner load and timing variance
|
||||
lastCheck := stats["last_check"].(time.Time)
|
||||
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_DefaultConfig tests default configuration
|
||||
func TestHealthMonitor_DefaultConfig(t *testing.T) {
|
||||
config := DefaultHealthMonitorConfig()
|
||||
|
||||
assert.Equal(t, 5*time.Second, config.CheckInterval)
|
||||
assert.Equal(t, 3*time.Second, config.Timeout)
|
||||
assert.Equal(t, 3, config.UnhealthyThreshold)
|
||||
assert.Nil(t, config.OnHealthChange)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
|
||||
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1, // Very small pool
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 50 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Get the only connection, blocking health checks
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for health check attempts
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Health monitor might mark as unhealthy due to timeouts
|
||||
stats := hm.GetStats()
|
||||
t.Logf("Stats with blocked pool: %+v", stats)
|
||||
|
||||
// Return connection
|
||||
pool.Put(conn)
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Should recover
|
||||
assert.True(t, hm.IsHealthy())
|
||||
}
|
||||
|
||||
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
|
||||
func TestConnectionPool_WithHealthChecks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn))
|
||||
|
||||
// Use connection
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse and validate
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
|
||||
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 3,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get and return a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
|
||||
initialTotal := pool.totalConns.Load()
|
||||
|
||||
// Close the connection manually to make it stale
|
||||
conn.Close()
|
||||
|
||||
// Get another connection - should detect stale and create new
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn2))
|
||||
|
||||
pool.Put(conn2)
|
||||
|
||||
// Total connections might be same or less (stale removed)
|
||||
finalTotal := pool.totalConns.Load()
|
||||
assert.LessOrEqual(t, finalTotal, initialTotal+1)
|
||||
}
|
||||
+337
@@ -0,0 +1,337 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionPool manages a pool of Redis connections
|
||||
// Pure-Go implementation compatible with Yaegi
|
||||
type ConnectionPool struct {
|
||||
config *PoolConfig
|
||||
|
||||
connections chan *RedisConn
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Metrics
|
||||
activeConns atomic.Int32
|
||||
totalConns atomic.Int32
|
||||
gets atomic.Int64
|
||||
puts atomic.Int64
|
||||
timeouts atomic.Int64
|
||||
}
|
||||
|
||||
// PoolConfig holds connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if config.MaxConnections <= 0 {
|
||||
config.MaxConnections = 10
|
||||
}
|
||||
|
||||
if config.ConnectTimeout == 0 {
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
pool := &ConnectionPool{
|
||||
config: config,
|
||||
connections: make(chan *RedisConn, config.MaxConnections),
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool or creates a new one
|
||||
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.gets.Add(1)
|
||||
|
||||
// Try to get a connection with validation
|
||||
maxAttempts := 3
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
var conn *RedisConn
|
||||
var err error
|
||||
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
// Wait before retry with exponential backoff
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
p.totalConns.Add(1)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Pool exhausted, wait for a connection with timeout
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
p.timeouts.Add(1)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(p.config.ConnectTimeout):
|
||||
p.timeouts.Add(1)
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to get healthy connection after retries")
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.puts.Add(1)
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
|
||||
// Return to pool (non-blocking)
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *ConnectionPool) Close() error {
|
||||
if p.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
close(p.connections)
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns pool statistics
|
||||
func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"active_connections": p.activeConns.Load(),
|
||||
"total_connections": p.totalConns.Load(),
|
||||
"max_connections": p.config.MaxConnections,
|
||||
"gets": p.gets.Load(),
|
||||
"puts": p.puts.Load(),
|
||||
"timeouts": p.timeouts.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
redisConn := &RedisConn{
|
||||
conn: conn,
|
||||
readTimeout: p.config.ReadTimeout,
|
||||
writeTimeout: p.config.WriteTimeout,
|
||||
}
|
||||
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return redisConn, nil
|
||||
}
|
||||
|
||||
// RedisConn represents a single Redis connection
|
||||
type RedisConn struct {
|
||||
conn net.Conn
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Do executes a Redis command and returns the response
|
||||
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
if c.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
}
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
// Protect against possible overflow
|
||||
if len(s) > maxTotalArgBytes-totalBytes {
|
||||
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
|
||||
}
|
||||
totalBytes += len(s)
|
||||
if totalBytes > maxTotalArgBytes {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
writer := NewRESPWriter(c.conn)
|
||||
err := writer.WriteCommand(cmdArgs...)
|
||||
writer.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
c.closed.Store(true)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
reader := NewRESPReader(c.conn)
|
||||
resp, err := reader.ReadResponse()
|
||||
reader.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNilResponse) {
|
||||
c.closed.Store(true)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *RedisConn) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionHealthy validates a connection is still working
|
||||
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
if conn == nil || conn.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
+620
@@ -0,0 +1,620 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConnectionPool_BasicOperations tests basic pool operations
|
||||
func TestConnectionPool_BasicOperations(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
t.Run("GetAndPutConnection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Verify connection works
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse same connection
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
stats := pool.Stats()
|
||||
require.NotNil(t, stats)
|
||||
|
||||
assert.Contains(t, stats, "active_connections")
|
||||
assert.Contains(t, stats, "total_connections")
|
||||
assert.Contains(t, stats, "max_connections")
|
||||
assert.Equal(t, 5, stats["max_connections"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_MaxConnections tests pool size limits
|
||||
func TestConnectionPool_MaxConnections(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
maxConns := 3
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: maxConns,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get max connections
|
||||
conns := make([]*RedisConn, maxConns)
|
||||
for i := 0; i < maxConns; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
assert.Equal(t, int32(maxConns), stats["total_connections"])
|
||||
assert.Equal(t, int32(maxConns), stats["active_connections"])
|
||||
|
||||
// Try to get one more - should block/timeout
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Get(ctx2)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn)
|
||||
|
||||
// Return one connection
|
||||
pool.Put(conns[0])
|
||||
|
||||
// Now we should be able to get a connection
|
||||
conn, err = pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
for i := 1; i < maxConns; i++ {
|
||||
pool.Put(conns[i])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
|
||||
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
numGoroutines := 50
|
||||
numOperations := 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Spawn goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Do some work
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Small delay
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
t.Logf("Final stats: %+v", stats)
|
||||
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
|
||||
assert.Equal(t, int32(0), stats["active_connections"])
|
||||
}
|
||||
|
||||
// TestConnectionPool_ContextCancellation tests context cancellation
|
||||
func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get the only connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn2)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Authentication tests auth support
|
||||
func TestConnectionPool_Authentication(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
// Set password on miniredis
|
||||
mr.server.RequireAuth("secret-password")
|
||||
|
||||
t.Run("CorrectPassword", func(t *testing.T) {
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "secret-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
})
|
||||
|
||||
t.Run("WrongPassword", func(t *testing.T) {
|
||||
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "wrong-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
_, err := NewConnectionPool(config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_DatabaseSelection tests DB selection
|
||||
func TestConnectionPool_DatabaseSelection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
DB: 5,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connection should be on DB 5
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_ClosedConnection tests handling closed connections
|
||||
func TestConnectionPool_ClosedConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close it manually
|
||||
conn.Close()
|
||||
|
||||
// Try to use it
|
||||
_, err = conn.Do("PING")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Return to pool (should be discarded)
|
||||
pool.Put(conn)
|
||||
|
||||
// Get new connection - should create a new one
|
||||
conn2, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
resp, err := conn2.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Close tests pool closure
|
||||
func TestConnectionPool_Close(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get some connections
|
||||
conns := make([]*RedisConn, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Return them
|
||||
for _, conn := range conns {
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Close pool
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get connection from closed pool
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Close again should be no-op
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Timeouts tests various timeout scenarios
|
||||
func TestConnectionPool_Timeouts(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
WriteTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Normal operation should work
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestRedisConn_DoCommand tests the Do method
|
||||
func TestRedisConn_DoCommand(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SET and GET", func(t *testing.T) {
|
||||
// SET
|
||||
resp, err := conn.Do("SET", "testkey", "testvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// GET
|
||||
resp, err = conn.Do("GET", "testkey")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testvalue", resp)
|
||||
})
|
||||
|
||||
t.Run("DEL", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "delkey", "delvalue")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DEL
|
||||
resp, err := conn.Do("DEL", "delkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
|
||||
t.Run("EXISTS", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "existskey", "value")
|
||||
require.NoError(t, err)
|
||||
|
||||
// EXISTS - key exists
|
||||
resp, err := conn.Do("EXISTS", "existskey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// EXISTS - key doesn't exist
|
||||
resp, err = conn.Do("EXISTS", "nonexistent")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("TTL commands", func(t *testing.T) {
|
||||
// SETEX
|
||||
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// TTL
|
||||
resp, err = conn.Do("TTL", "ttlkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
ttl, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, ttl, int64(0))
|
||||
assert.LessOrEqual(t, ttl, int64(60))
|
||||
})
|
||||
}
|
||||
|
||||
// TestPoolConfig_Defaults tests default configuration values
|
||||
func TestPoolConfig_Defaults(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
// Leave other fields at zero values
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Should use defaults
|
||||
assert.Equal(t, 10, pool.config.MaxConnections)
|
||||
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
|
||||
|
||||
// Verify it works
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_NilConnection tests handling nil connections
|
||||
func TestConnectionPool_NilConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Putting nil should be safe
|
||||
pool.Put(nil)
|
||||
|
||||
// Pool should still work
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StatsTracking tests metrics tracking
|
||||
func TestConnectionPool_StatsTracking(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := pool.Stats()
|
||||
initialGets := stats["gets"].(int64)
|
||||
initialPuts := stats["puts"].(int64)
|
||||
|
||||
// Perform operations
|
||||
numOps := 10
|
||||
for i := 0; i < numOps; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Check updated stats
|
||||
stats = pool.Stats()
|
||||
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
|
||||
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
|
||||
assert.Equal(t, int32(0), stats["active_connections"].(int32))
|
||||
}
|
||||
|
||||
// TestRedisConn_TooManyArguments tests protection against allocation overflow
|
||||
func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("AcceptableArgumentCount", func(t *testing.T) {
|
||||
// Should work with reasonable number of args
|
||||
args := make([]string, 100)
|
||||
for i := range args {
|
||||
args[i] = "value"
|
||||
}
|
||||
_, err := conn.Do("MSET", args...)
|
||||
// May fail due to Redis constraints, but shouldn't panic or error on overflow
|
||||
// Just verify it doesn't trigger our overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RejectExcessiveArguments", func(t *testing.T) {
|
||||
// Create an absurdly large number of arguments that would cause overflow
|
||||
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
|
||||
args := make([]string, 1<<20) // 1,048,576 args
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("MSET", args...)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many arguments")
|
||||
})
|
||||
|
||||
t.Run("BoundaryCase", func(t *testing.T) {
|
||||
// Test exactly at the boundary (maxSafeArgs)
|
||||
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("ECHO", args...)
|
||||
// Should not error due to overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
}
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisBackend_BasicOperations tests basic Redis operations
|
||||
func TestRedisBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "redis-test-key"
|
||||
value := []byte("redis-test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "redis-delete-key"
|
||||
value := []byte("redis-delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "redis-exists-key"
|
||||
value := []byte("redis-exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
|
||||
func TestRedisBackend_KeyPrefixing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "test:prefix:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "my-key"
|
||||
value := []byte("my-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that key is stored with prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, "test:prefix:my-key", keys[0])
|
||||
|
||||
// Get should work without prefix
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// TestRedisBackend_TTLExpiration tests TTL handling
|
||||
func TestRedisBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward time in miniredis
|
||||
mr.FastForward(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLRemaining", func(t *testing.T) {
|
||||
key := "ttl-remaining-key"
|
||||
value := []byte("ttl-remaining-value")
|
||||
ttl := 10 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL is less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_Clear tests clearing all keys
|
||||
func TestRedisBackend_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "clear-test:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add multiple keys
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify keys exist
|
||||
keys := mr.CheckKeys()
|
||||
assert.Len(t, keys, 10)
|
||||
|
||||
// Clear all
|
||||
err = backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
keys = mr.CheckKeys()
|
||||
assert.Len(t, keys, 0)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
|
||||
func TestRedisBackend_ConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Try to connect to non-existent Redis
|
||||
config := DefaultRedisConfig("localhost:9999")
|
||||
_, err := NewRedisBackend(config)
|
||||
assert.Error(t, err, "Should fail to connect to non-existent Redis")
|
||||
}
|
||||
|
||||
// TestRedisBackend_RedisErrors tests handling of Redis errors
|
||||
func TestRedisBackend_RedisErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated error")
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Clear error
|
||||
mr.ClearError()
|
||||
|
||||
// Operations should work again
|
||||
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConcurrentAccess tests thread safety
|
||||
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0))
|
||||
}
|
||||
|
||||
// TestRedisBackend_Stats tests statistics tracking
|
||||
func TestRedisBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add and access items
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Get(ctx, "key1") // Hit
|
||||
backend.Get(ctx, "non-existent") // Miss
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Ping tests health check
|
||||
func TestRedisBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Close tests proper cleanup
|
||||
func TestRedisBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Double close should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_UpdateExisting tests updating existing keys
|
||||
func TestRedisBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute)
|
||||
}
|
||||
|
||||
// TestRedisBackend_LargeValues tests handling of large values
|
||||
func TestRedisBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_EmptyValues tests handling of empty values
|
||||
func TestRedisBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_PipelineOperations tests batch operations
|
||||
func TestRedisBackend_PipelineOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("batch-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("batch-value-%d", i))
|
||||
items[key] = value
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrieved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany", func(t *testing.T) {
|
||||
// Set test data
|
||||
testData := GenerateTestData(5)
|
||||
for key, value := range testData {
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Get all keys
|
||||
keys := make([]string, 0, len(testData))
|
||||
for key := range testData {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, len(testData))
|
||||
|
||||
for key, expectedValue := range testData {
|
||||
retrievedValue, exists := results[key]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyWithNonExistent", func(t *testing.T) {
|
||||
keys := []string{"exists-1", "non-existent", "exists-2"}
|
||||
|
||||
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
|
||||
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
assert.Equal(t, []byte("value-1"), results["exists-1"])
|
||||
assert.Equal(t, []byte("value-2"), results["exists-2"])
|
||||
_, exists := results["non-existent"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_NoPrefix tests operation without prefix
|
||||
func TestRedisBackend_NoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "" // No prefix
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "no-prefix-key"
|
||||
value := []byte("no-prefix-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check key is stored without prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, key, keys[0])
|
||||
}
|
||||
Vendored
+251
@@ -0,0 +1,251 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
|
||||
func (w *RESPWriter) WriteCommand(args ...string) error {
|
||||
// Write array header
|
||||
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write each argument as bulk string
|
||||
for _, arg := range args {
|
||||
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RESPReader reads RESP protocol messages
|
||||
type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
func (r *RESPReader) ReadResponse() (interface{}, error) {
|
||||
typeByte, err := r.r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typeByte {
|
||||
case '+': // Simple string
|
||||
return r.readSimpleString()
|
||||
case '-': // Error
|
||||
return nil, r.readError()
|
||||
case ':': // Integer
|
||||
return r.readInteger()
|
||||
case '$': // Bulk string
|
||||
return r.readBulkString()
|
||||
case '*': // Array
|
||||
return r.readArray()
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
|
||||
}
|
||||
}
|
||||
|
||||
// readSimpleString reads a simple string (+OK\r\n)
|
||||
func (r *RESPReader) readSimpleString() (string, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// readError reads an error message (-Error message\r\n)
|
||||
func (r *RESPReader) readError() error {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(line)
|
||||
}
|
||||
|
||||
// readInteger reads an integer (:1000\r\n)
|
||||
func (r *RESPReader) readInteger() (int64, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseInt(line, 10, 64)
|
||||
}
|
||||
|
||||
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
|
||||
func (r *RESPReader) readBulkString() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil bulk string
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read exactly 'length' bytes plus \r\n
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r.r, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify \r\n terminator
|
||||
if buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
|
||||
func (r *RESPReader) readArray() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil array
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read each element
|
||||
result := make([]interface{}, length)
|
||||
for i := 0; i < length; i++ {
|
||||
elem, err := r.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = elem
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readLine reads a line terminated by \r\n
|
||||
func (r *RESPReader) readLine() (string, error) {
|
||||
line, err := r.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Remove \r\n
|
||||
line = strings.TrimSuffix(line, "\r\n")
|
||||
if !strings.HasSuffix(line+"\r\n", "\r\n") {
|
||||
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// RESPString extracts a string from RESP response
|
||||
func RESPString(resp interface{}) (string, error) {
|
||||
if resp == nil {
|
||||
return "", ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("expected string, got %T", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// RESPInt extracts an integer from RESP response
|
||||
func RESPInt(resp interface{}) (int64, error) {
|
||||
if resp == nil {
|
||||
return 0, ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("expected integer, got %T", resp)
|
||||
}
|
||||
}
|
||||
Vendored
+495
@@ -0,0 +1,495 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRESPWriter_WriteCommand tests RESP command writing
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
args: []string{"PING"},
|
||||
expected: "*1\r\n$4\r\nPING\r\n",
|
||||
},
|
||||
{
|
||||
name: "SET command",
|
||||
args: []string{"SET", "key", "value"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "SETEX command",
|
||||
args: []string{"SETEX", "mykey", "60", "myvalue"},
|
||||
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "DEL with multiple keys",
|
||||
args: []string{"DEL", "key1", "key2", "key3"},
|
||||
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with empty string",
|
||||
args: []string{"SET", "key", ""},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with special characters",
|
||||
args: []string{"SET", "key", "val\r\nue"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(buf)
|
||||
|
||||
err := writer.WriteCommand(tt.args...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadSimpleString tests reading simple strings
|
||||
func TestRESPReader_ReadSimpleString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "OK response",
|
||||
input: "+OK\r\n",
|
||||
expected: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "PONG response",
|
||||
input: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "+\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "+Hello World\r\n",
|
||||
expected: "Hello World",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadError tests reading error messages
|
||||
func TestRESPReader_ReadError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ERR error",
|
||||
input: "-ERR unknown command\r\n",
|
||||
expectedError: "ERR unknown command",
|
||||
},
|
||||
{
|
||||
name: "WRONGTYPE error",
|
||||
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
|
||||
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
|
||||
},
|
||||
{
|
||||
name: "Simple error",
|
||||
input: "-Error\r\n",
|
||||
expectedError: "Error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadInteger tests reading integers
|
||||
func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Zero",
|
||||
input: ":0\r\n",
|
||||
expected: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Positive integer",
|
||||
input: ":1000\r\n",
|
||||
expected: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Negative integer",
|
||||
input: ":-1\r\n",
|
||||
expected: -1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Large integer",
|
||||
input: ":9223372036854775807\r\n",
|
||||
expected: 9223372036854775807,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid integer",
|
||||
input: ":abc\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Simple bulk string",
|
||||
input: "$6\r\nfoobar\r\n",
|
||||
expected: "foobar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty bulk string",
|
||||
input: "$0\r\n\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil bulk string",
|
||||
input: "$-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Binary safe bulk string",
|
||||
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
|
||||
expected: "\x00\x01\x02\x03\x04",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid length",
|
||||
input: "$abc\r\ntest\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadArray tests reading arrays
|
||||
func TestRESPReader_ReadArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Empty array",
|
||||
input: "*0\r\n",
|
||||
expected: []interface{}{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of bulk strings",
|
||||
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
|
||||
expected: []interface{}{
|
||||
"foo",
|
||||
"bar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of integers",
|
||||
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Mixed array",
|
||||
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
int64(4),
|
||||
"foobar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil array",
|
||||
input: "*-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Nested arrays",
|
||||
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
|
||||
expected: []interface{}{
|
||||
[]interface{}{"foo", "bar"},
|
||||
[]interface{}{"baz"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_InvalidInput tests error handling for invalid input
|
||||
func TestRESPReader_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Unknown type byte",
|
||||
input: "?invalid\r\n",
|
||||
},
|
||||
{
|
||||
name: "Incomplete response",
|
||||
input: "+OK",
|
||||
},
|
||||
{
|
||||
name: "Missing CRLF in bulk string",
|
||||
input: "$5\r\nhello",
|
||||
},
|
||||
{
|
||||
name: "Truncated array",
|
||||
input: "*3\r\n:1\r\n:2\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_EOF tests handling of EOF
|
||||
func TestRESPReader_EOF(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(""))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, io.EOF))
|
||||
}
|
||||
|
||||
// TestRESPHelpers tests helper functions
|
||||
func TestRESPHelpers(t *testing.T) {
|
||||
t.Run("RESPString", func(t *testing.T) {
|
||||
// Valid string
|
||||
result, err := RESPString("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", result)
|
||||
|
||||
// Byte slice
|
||||
result, err = RESPString([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "world", result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPString(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPString(123)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("RESPInt", func(t *testing.T) {
|
||||
// Valid int64
|
||||
result, err := RESPInt(int64(42))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Valid int
|
||||
result, err = RESPInt(42)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPInt(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPInt("string")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command []string
|
||||
response string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
command: []string{"PING"},
|
||||
response: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
},
|
||||
{
|
||||
name: "GET command with result",
|
||||
command: []string{"GET", "mykey"},
|
||||
response: "$7\r\nmyvalue\r\n",
|
||||
expected: "myvalue",
|
||||
},
|
||||
{
|
||||
name: "GET command with nil",
|
||||
command: []string{"GET", "nonexistent"},
|
||||
response: "$-1\r\n",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "DEL command",
|
||||
command: []string{"DEL", "key1", "key2"},
|
||||
response: ":2\r\n",
|
||||
expected: int64(2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Write command
|
||||
writeBuf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(writeBuf)
|
||||
err := writer.WriteCommand(tt.command...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
reader := NewRESPReader(strings.NewReader(tt.response))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.expected == nil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+198
@@ -0,0 +1,198 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLogger implements a simple logger for tests
|
||||
type TestLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewTestLogger(t *testing.T) *TestLogger {
|
||||
return &TestLogger{t: t}
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debug(format string, args ...interface{}) {
|
||||
l.t.Logf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Info(format string, args ...interface{}) {
|
||||
l.t.Logf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Error(format string, args ...interface{}) {
|
||||
l.t.Logf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Debug(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Infof(format string, args ...interface{}) {
|
||||
l.Info(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Error(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Warnf(format string, args ...interface{}) {
|
||||
l.t.Logf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
// MiniredisServer manages a miniredis instance for testing
|
||||
type MiniredisServer struct {
|
||||
server *miniredis.Miniredis
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// NewMiniredisServer creates a new miniredis server for testing
|
||||
func NewMiniredisServer(t *testing.T) *MiniredisServer {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err, "failed to start miniredis")
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// Verify connection
|
||||
ctx := context.Background()
|
||||
err = client.Ping(ctx).Err()
|
||||
require.NoError(t, err, "failed to ping miniredis")
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return &MiniredisServer{
|
||||
server: mr,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAddr returns the address of the miniredis server
|
||||
func (m *MiniredisServer) GetAddr() string {
|
||||
return m.server.Addr()
|
||||
}
|
||||
|
||||
// GetClient returns the Redis client
|
||||
func (m *MiniredisServer) GetClient() *redis.Client {
|
||||
return m.client
|
||||
}
|
||||
|
||||
// FastForward advances the miniredis server's time
|
||||
func (m *MiniredisServer) FastForward(d time.Duration) {
|
||||
m.server.FastForward(d)
|
||||
}
|
||||
|
||||
// FlushAll removes all keys from the database
|
||||
func (m *MiniredisServer) FlushAll() {
|
||||
m.server.FlushAll()
|
||||
}
|
||||
|
||||
// SetError simulates a Redis error
|
||||
func (m *MiniredisServer) SetError(err string) {
|
||||
m.server.SetError(err)
|
||||
}
|
||||
|
||||
// ClearError clears any simulated errors
|
||||
func (m *MiniredisServer) ClearError() {
|
||||
m.server.SetError("")
|
||||
}
|
||||
|
||||
// CheckKeys verifies that specific keys exist in Redis
|
||||
func (m *MiniredisServer) CheckKeys() []string {
|
||||
return m.server.Keys()
|
||||
}
|
||||
|
||||
// Close closes the miniredis server
|
||||
func (m *MiniredisServer) Close() {
|
||||
m.server.Close()
|
||||
}
|
||||
|
||||
// Restart restarts the miniredis server
|
||||
func (m *MiniredisServer) Restart() {
|
||||
m.server.Restart()
|
||||
}
|
||||
|
||||
// TestConfig provides default test configuration
|
||||
type TestConfig struct {
|
||||
MaxSize int
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultTestConfig returns a standard test configuration
|
||||
func DefaultTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
CleanupInterval: 1 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTestData creates test cache data
|
||||
func GenerateTestData(count int) map[string][]byte {
|
||||
data := make(map[string][]byte, count)
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("test-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("test-value-%d", i))
|
||||
data[key] = value
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// GenerateLargeValue creates a large test value
|
||||
func GenerateLargeValue(sizeBytes int) []byte {
|
||||
return make([]byte, sizeBytes)
|
||||
}
|
||||
|
||||
// AssertCacheStats is a helper to verify cache statistics
|
||||
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
|
||||
t.Helper()
|
||||
|
||||
hits, ok := stats["hits"].(int64)
|
||||
require.True(t, ok, "hits should be int64")
|
||||
require.Equal(t, expectedHits, hits, "unexpected hit count")
|
||||
|
||||
misses, ok := stats["misses"].(int64)
|
||||
require.True(t, ok, "misses should be int64")
|
||||
require.Equal(t, expectedMisses, misses, "unexpected miss count")
|
||||
}
|
||||
|
||||
// WaitForCondition waits for a condition to be true or times out
|
||||
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(checkInterval)
|
||||
}
|
||||
t.Fatal("timeout waiting for condition")
|
||||
}
|
||||
|
||||
// AssertEventuallyExpires verifies that a key eventually expires
|
||||
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
|
||||
_, _, exists, err := backend.Get(ctx, key)
|
||||
return err == nil && !exists
|
||||
})
|
||||
}
|
||||
Vendored
+1
-1
@@ -355,7 +355,7 @@ func (c *Cache) removeItem(key string, item *Item) {
|
||||
|
||||
func (c *Cache) evictLRU() {
|
||||
if elem := c.lruList.Back(); elem != nil {
|
||||
item := elem.Value.(*Item)
|
||||
item, _ := elem.Value.(*Item) // Safe to ignore: type assertion from known type
|
||||
c.removeItem(item.Key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
c.logger.Debugf("Cache: Evicted LRU item key=%s", item.Key)
|
||||
|
||||
Vendored
+96
-10
@@ -1880,19 +1880,20 @@ func TestConcurrentManagerOperations(t *testing.T) {
|
||||
// TestTTLExpirationAndCleanup tests TTL expiration and cleanup routines comprehensively
|
||||
func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 10 * time.Millisecond
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
config.EnableAutoCleanup = true
|
||||
cache := New(config)
|
||||
defer cache.Close()
|
||||
|
||||
// Test various TTL scenarios
|
||||
// Note: Timing increased 5x to account for race detector overhead
|
||||
testCases := []struct {
|
||||
key string
|
||||
ttl time.Duration
|
||||
}{
|
||||
{"very-short", 5 * time.Millisecond},
|
||||
{"short", 25 * time.Millisecond},
|
||||
{"medium", 100 * time.Millisecond},
|
||||
{"very-short", 25 * time.Millisecond},
|
||||
{"short", 125 * time.Millisecond},
|
||||
{"medium", 500 * time.Millisecond},
|
||||
{"long", 1 * time.Hour},
|
||||
}
|
||||
|
||||
@@ -1908,13 +1909,13 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for very short items to expire
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
if _, exists := cache.Get("very-short"); exists {
|
||||
t.Error("Very short item should be expired")
|
||||
}
|
||||
|
||||
// Wait for short items to expire
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
if _, exists := cache.Get("short"); exists {
|
||||
t.Error("Short item should be expired")
|
||||
}
|
||||
@@ -1930,16 +1931,16 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test manual cleanup
|
||||
cache.Set("manual-cleanup", "value", 1*time.Millisecond)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.Set("manual-cleanup", "value", 5*time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
cache.Cleanup()
|
||||
|
||||
// Add many expired items to test bulk cleanup
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bulk-%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 1*time.Millisecond)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 5*time.Millisecond)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
|
||||
sizeBefore := cache.Size()
|
||||
cache.Cleanup()
|
||||
@@ -2038,3 +2039,88 @@ func TestCacheStatisticsAndMetrics(t *testing.T) {
|
||||
t.Error("Memory usage should increase after adding large item")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// noOpLogger Tests
|
||||
// ============================================================================
|
||||
|
||||
// TestNoOpLogger_AllMethods tests all noOpLogger methods to ensure they don't panic
|
||||
func TestNoOpLogger_AllMethods(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Test simple message methods
|
||||
logger.Debug("test debug message")
|
||||
logger.Info("test info message")
|
||||
logger.Error("test error message")
|
||||
logger.Warn("test warn message")
|
||||
logger.Fatal("test fatal message")
|
||||
|
||||
// Test formatted message methods
|
||||
logger.Debugf("test debug: %s", "value")
|
||||
logger.Infof("test info: %s", "value")
|
||||
logger.Errorf("test error: %s", "value")
|
||||
logger.Warnf("test warn: %s", "value")
|
||||
logger.Fatalf("test fatal: %s", "value")
|
||||
|
||||
// If we reach here, all methods executed without panicking
|
||||
// This is expected behavior for a no-op logger
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithField verifies WithField returns the same logger
|
||||
func TestNoOpLogger_WithField(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
result := logger.WithField("key", "value")
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithField should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithField")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithFields verifies WithFields returns the same logger
|
||||
func TestNoOpLogger_WithFields(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
"key3": true,
|
||||
}
|
||||
|
||||
result := logger.WithFields(fields)
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithFields should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithFields")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_Chaining verifies method chaining works
|
||||
func TestNoOpLogger_Chaining(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Use WithField and verify it returns a usable logger
|
||||
result := logger.WithField("key1", "value1")
|
||||
|
||||
// Verify the result can be used for logging (Logger interface methods)
|
||||
result.Info("info after WithField")
|
||||
result.Infof("infof after WithField: %s", "test")
|
||||
result.Debug("debug after WithField")
|
||||
result.Debugf("debugf after WithField: %d", 123)
|
||||
result.Error("error after WithField")
|
||||
result.Errorf("errorf after WithField: %v", true)
|
||||
|
||||
// Use WithFields and verify it returns a usable logger
|
||||
result2 := logger.WithFields(map[string]interface{}{
|
||||
"key2": "value2",
|
||||
"key3": 123,
|
||||
})
|
||||
|
||||
// Verify the result can be used for logging
|
||||
result2.Infof("message after WithFields: %s", "test")
|
||||
}
|
||||
|
||||
Vendored
+2
@@ -1,3 +1,5 @@
|
||||
// Package cache provides high-performance caching implementations for OIDC tokens, metadata, and JWKs.
|
||||
// It includes compatibility wrappers for backward compatibility with existing cache interfaces.
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
||||
+329
@@ -0,0 +1,329 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrCircuitOpen is returned when the circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTooManyRequests is returned when too many requests are made in half-open state
|
||||
ErrTooManyRequests = errors.New("too many requests in half-open state")
|
||||
)
|
||||
|
||||
// State represents the state of the circuit breaker
|
||||
type State int32
|
||||
|
||||
const (
|
||||
// StateClosed allows all operations to pass through
|
||||
StateClosed State = iota
|
||||
|
||||
// StateOpen blocks all operations
|
||||
StateOpen
|
||||
|
||||
// StateHalfOpen allows a limited number of operations to test recovery
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
// String returns the string representation of the state
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
return &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
FailureThreshold: 0.6,
|
||||
Timeout: 30 * time.Second,
|
||||
HalfOpenMaxRequests: 3,
|
||||
ResetTimeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
config: config,
|
||||
lastStateChange: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a function through the circuit breaker
|
||||
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
|
||||
if !cb.AllowRequest() {
|
||||
cb.rejectedRequests.Add(1)
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
cb.totalRequests.Add(1)
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.RecordFailure()
|
||||
} else {
|
||||
cb.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AllowRequest checks if a request is allowed to proceed
|
||||
func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// Check if timeout has passed and we should try half-open
|
||||
cb.timeMu.RLock()
|
||||
shouldRetry := time.Now().After(cb.nextRetryTime)
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
if shouldRetry {
|
||||
cb.setState(StateHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.timeMu.Lock()
|
||||
cb.lastSuccessTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Reset consecutive failures
|
||||
cb.consecutiveFailures.Store(0)
|
||||
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.totalFailures.Add(1)
|
||||
failures := cb.consecutiveFailures.Add(1)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
// Check failure rate
|
||||
total := cb.totalRequests.Load()
|
||||
failureCount := cb.totalFailures.Load()
|
||||
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
// Any failure in half-open state reopens the circuit
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
// openCircuit transitions to open state
|
||||
func (cb *CircuitBreaker) openCircuit() {
|
||||
cb.setState(StateOpen)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
|
||||
cb.timeMu.Unlock()
|
||||
}
|
||||
|
||||
// GetState returns the current state
|
||||
func (cb *CircuitBreaker) GetState() State {
|
||||
return State(cb.state.Load())
|
||||
}
|
||||
|
||||
// setState changes the circuit breaker state
|
||||
func (cb *CircuitBreaker) setState(newState State) {
|
||||
oldState := State(cb.state.Swap(int32(newState)))
|
||||
|
||||
if oldState != newState {
|
||||
cb.stateTransitions.Add(1)
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMu.Unlock()
|
||||
|
||||
if cb.config.OnStateChange != nil {
|
||||
cb.config.OnStateChange(oldState, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.totalRequests.Store(0)
|
||||
cb.totalFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
cb.rejectedRequests.Store(0)
|
||||
cb.stateTransitions.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = now
|
||||
cb.lastSuccessTime = now
|
||||
cb.nextRetryTime = now
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = now
|
||||
cb.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// Stats returns circuit breaker statistics
|
||||
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
cb.timeMu.RLock()
|
||||
lastFailure := cb.lastFailureTime
|
||||
lastSuccess := cb.lastSuccessTime
|
||||
nextRetry := cb.nextRetryTime
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
cb.stateMu.RLock()
|
||||
lastChange := cb.lastStateChange
|
||||
cb.stateMu.RUnlock()
|
||||
|
||||
totalReq := cb.totalRequests.Load()
|
||||
totalFail := cb.totalFailures.Load()
|
||||
successRate := float64(0)
|
||||
if totalReq > 0 {
|
||||
successRate = float64(totalReq-totalFail) / float64(totalReq)
|
||||
}
|
||||
|
||||
return CircuitBreakerStats{
|
||||
State: cb.GetState(),
|
||||
ConsecutiveFailures: cb.consecutiveFailures.Load(),
|
||||
TotalRequests: totalReq,
|
||||
TotalFailures: totalFail,
|
||||
SuccessRate: successRate,
|
||||
RejectedRequests: cb.rejectedRequests.Load(),
|
||||
StateTransitions: cb.stateTransitions.Load(),
|
||||
LastFailureTime: lastFailure,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastStateChange: lastChange,
|
||||
NextRetryTime: nextRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
func (cb *CircuitBreaker) IsHealthy() bool {
|
||||
return cb.GetState() != StateOpen
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
|
||||
type CircuitBreakerBackend struct {
|
||||
backend backends.CacheBackend
|
||||
cb *CircuitBreaker
|
||||
}
|
||||
|
||||
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
|
||||
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreakerBackend{
|
||||
backend: b,
|
||||
cb: NewCircuitBreaker(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Set(ctx, key, value, ttl)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return nil, 0, false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
value, ttl, exists, err := c.backend.Get(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
deleted, err := c.backend.Delete(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
exists, err := c.backend.Exists(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Clear(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including circuit breaker state
|
||||
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
|
||||
stats := c.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
cbStats := c.cb.Stats()
|
||||
stats["circuit_breaker"] = map[string]interface{}{
|
||||
"state": cbStats.State.String(),
|
||||
"consecutive_failures": cbStats.ConsecutiveFailures,
|
||||
"total_requests": cbStats.TotalRequests,
|
||||
"total_failures": cbStats.TotalFailures,
|
||||
"success_rate": cbStats.SuccessRate,
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Ping(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the backend
|
||||
func (c *CircuitBreakerBackend) Close() error {
|
||||
return c.backend.Close()
|
||||
}
|
||||
@@ -0,0 +1,561 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockBackend is a simple mock implementation for testing
|
||||
type mockBackend struct {
|
||||
data map[string]mockEntry
|
||||
mu sync.RWMutex
|
||||
failSet bool
|
||||
failGet bool
|
||||
failDelete bool
|
||||
failExists bool
|
||||
failClear bool
|
||||
failPing bool
|
||||
callCount int
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockBackend() *mockBackend {
|
||||
return &mockBackend{
|
||||
data: make(map[string]mockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failSet {
|
||||
return errors.New("mock set error")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
if ttl == 0 {
|
||||
expiresAt = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
m.data[key] = mockEntry{
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failGet {
|
||||
return nil, 0, false, errors.New("mock get error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
ttl := time.Until(entry.expiresAt)
|
||||
return entry.value, ttl, true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failDelete {
|
||||
return false, errors.New("mock delete error")
|
||||
}
|
||||
|
||||
_, existed := m.data[key]
|
||||
delete(m.data, key)
|
||||
return existed, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failExists {
|
||||
return false, errors.New("mock exists error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Clear(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failClear {
|
||||
return errors.New("mock clear error")
|
||||
}
|
||||
|
||||
m.data = make(map[string]mockEntry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) GetStats() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"hits": int64(0),
|
||||
"misses": int64(0),
|
||||
"call_count": m.callCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Ping(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failPing {
|
||||
return errors.New("mock ping error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constructor Tests
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithDefaultConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
require.NotNil(t, cb)
|
||||
|
||||
// Verify it implements the interface (compile-time check)
|
||||
var _ backends.CacheBackend = cb
|
||||
}
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithCustomConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
FailureThreshold: 0.5,
|
||||
Timeout: 5 * time.Second,
|
||||
HalfOpenMaxRequests: 2,
|
||||
ResetTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
require.NotNil(t, cb)
|
||||
}
|
||||
|
||||
// Set Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, mockBE.callCount)
|
||||
|
||||
// Verify value was stored
|
||||
value, _, exists, _ := mockBE.Get(ctx, "key1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Circuit should be open now
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Get Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First set a value
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Now get it through circuit breaker
|
||||
value, _, exists, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _, _, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Get(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, _, _, err := cb.Get(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Delete Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Delete through circuit breaker
|
||||
deleted, err := cb.Delete(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
// Verify it's deleted
|
||||
exists, _ := mockBE.Exists(ctx, "key1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failDelete = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Delete(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Delete(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Exists Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Check existence through circuit breaker
|
||||
exists, err := cb.Exists(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failExists = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Exists(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Exists(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Clear Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some values
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
mockBE.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Clear through circuit breaker
|
||||
err := cb.Clear(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify cleared
|
||||
exists1, _ := mockBE.Exists(ctx, "key1")
|
||||
exists2, _ := mockBE.Exists(ctx, "key2")
|
||||
assert.False(t, exists1)
|
||||
assert.False(t, exists2)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failClear = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Clear(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Clear(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// GetStats Tests
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Perform some operations
|
||||
cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
cb.Get(ctx, "key1")
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
|
||||
// Should have circuit breaker stats
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
|
||||
cbStats, ok := stats["circuit_breaker"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// Verify circuit breaker stats fields
|
||||
assert.Contains(t, cbStats, "state")
|
||||
assert.Contains(t, cbStats, "consecutive_failures")
|
||||
assert.Contains(t, cbStats, "total_requests")
|
||||
assert.Contains(t, cbStats, "total_failures")
|
||||
assert.Contains(t, cbStats, "success_rate")
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats_NilBackendStats(t *testing.T) {
|
||||
// Create a mock backend that returns nil stats
|
||||
mockBE := &mockBackendNilStats{}
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
}
|
||||
|
||||
// mockBackendNilStats returns nil from GetStats
|
||||
type mockBackendNilStats struct {
|
||||
mockBackend
|
||||
}
|
||||
|
||||
func (m *mockBackendNilStats) GetStats() map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Ping(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failPing = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Ping(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Close Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Close(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
err := cb.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Circuit Recovery Test
|
||||
|
||||
func TestCircuitBreakerBackend_CircuitRecovery(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Fix the backend
|
||||
mockBE.mu.Lock()
|
||||
mockBE.failSet = false
|
||||
mockBE.mu.Unlock()
|
||||
|
||||
// Circuit should be in half-open state, allow a test request
|
||||
err = cb.Set(ctx, "key3", []byte("value3"), 1*time.Minute)
|
||||
|
||||
// After success threshold is met, circuit should close
|
||||
if err == nil {
|
||||
// Circuit recovered
|
||||
err2 := cb.Set(ctx, "key4", []byte("value4"), 1*time.Minute)
|
||||
assert.NoError(t, err2, "Circuit should be closed after recovery")
|
||||
}
|
||||
}
|
||||
+553
@@ -0,0 +1,553 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreaker_StateTransitions tests state machine transitions
|
||||
func TestCircuitBreaker_StateTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Initial state is closed", func(t *testing.T) {
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Closed to Open after max failures", func(t *testing.T) {
|
||||
cb.Reset()
|
||||
|
||||
// Simulate failures
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
|
||||
// Open the circuit
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should allow request and transition to half-open
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First request transitions to half-open and succeeds
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should be in half-open after first request
|
||||
state := cb.GetState()
|
||||
assert.True(t, state == StateHalfOpen || state == StateClosed,
|
||||
"After first successful request, should be half-open or potentially closed")
|
||||
|
||||
if state == StateHalfOpen {
|
||||
// Need more successful requests to close
|
||||
// The exact number depends on implementation but should be within HalfOpenMaxRequests
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// After multiple successful requests, should eventually close
|
||||
finalState := cb.GetState()
|
||||
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
|
||||
"After successful requests, circuit should transition towards closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First call transitions to half-open, second failure reopens
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
|
||||
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Requests should be blocked
|
||||
err := cb.Execute(ctx, func() error {
|
||||
t.Fatal("Should not execute function when circuit is open")
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
|
||||
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit then wait for half-open
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// After timeout, circuit should allow transition to half-open
|
||||
// Execute HalfOpenMaxRequests successful requests
|
||||
successCount := 0
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
err := cb.Execute(ctx, func() error {
|
||||
successCount++
|
||||
return nil
|
||||
})
|
||||
// Should allow up to HalfOpenMaxRequests
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify we executed the expected number
|
||||
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
|
||||
|
||||
// After successful requests, circuit behavior depends on implementation
|
||||
// It could close (allowing more requests) or stay half-open (blocking)
|
||||
// The important thing is that we allowed exactly HalfOpenMaxRequests
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Have some failures (but less than max)
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
|
||||
// One success should reset failures
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats = cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentAccess tests thread safety
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 5,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of successes and failures
|
||||
cb.Execute(ctx, func() error {
|
||||
if (id+j)%3 == 0 {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Random state checks
|
||||
_ = cb.GetState()
|
||||
_ = cb.Stats()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without panics
|
||||
stats := cb.Stats()
|
||||
assert.NotNil(t, stats)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Stats tests statistics tracking
|
||||
func TestCircuitBreaker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Execute some requests
|
||||
cb.Execute(ctx, func() error { return nil }) // Success
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
|
||||
stats := cb.Stats()
|
||||
|
||||
assert.Equal(t, StateClosed, stats.State)
|
||||
assert.Equal(t, int64(3), stats.TotalRequests)
|
||||
assert.Equal(t, int64(2), stats.TotalFailures)
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests circuit reset
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open the circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
assert.Equal(t, int64(0), stats.TotalFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateChangeCallback tests state change notifications
|
||||
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 50 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
OnStateChange: func(from, to State) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger state transitions
|
||||
// Closed -> Open
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
// Should be open now
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Open -> HalfOpen on first request after timeout
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Execute more successful requests to trigger HalfOpen -> Closed
|
||||
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.Contains(t, transitions, "closed->open")
|
||||
assert.Contains(t, transitions, "open->half-open")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsHealthy tests health check
|
||||
func TestCircuitBreaker_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, cb.IsHealthy())
|
||||
|
||||
// Open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
|
||||
|
||||
// Wait for timeout and allow successful request
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Should be healthy after recovery
|
||||
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
|
||||
func TestCircuitBreaker_RapidFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Rapid failures
|
||||
for i := 0; i < 10; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("rapid error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
stats := cb.Stats()
|
||||
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
|
||||
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
timeout := 100 * time.Millisecond
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: timeout,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait just before timeout
|
||||
time.Sleep(timeout - 20*time.Millisecond)
|
||||
assert.False(t, cb.IsHealthy())
|
||||
|
||||
// Wait until after timeout
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
// After timeout, AllowRequest should return true for transition to half-open
|
||||
assert.True(t, cb.AllowRequest())
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_DefaultConfig tests default configuration
|
||||
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := NewCircuitBreaker(nil) // Should use defaults
|
||||
|
||||
assert.NotNil(t, cb)
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
|
||||
// Verify defaults by triggering circuit breaker behavior
|
||||
ctx := context.Background()
|
||||
|
||||
// Test that it takes 5 failures to open (default MaxFailures)
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
|
||||
|
||||
// 5th failure should open it
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateString tests state string representation
|
||||
func TestCircuitBreaker_StateString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "closed", StateClosed.String())
|
||||
assert.Equal(t, "open", StateOpen.String())
|
||||
assert.Equal(t, "half-open", StateHalfOpen.String())
|
||||
assert.Equal(t, "unknown", State(999).String())
|
||||
}
|
||||
|
||||
// Benchmark circuit breaker performance
|
||||
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 100,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1000,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
if i%10 == 0 {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
+375
@@ -0,0 +1,375 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a backend
|
||||
type HealthStatus int32
|
||||
|
||||
const (
|
||||
// HealthUnknown indicates unknown health status
|
||||
HealthUnknown HealthStatus = iota
|
||||
|
||||
// HealthHealthy indicates the backend is healthy
|
||||
HealthHealthy
|
||||
|
||||
// HealthDegraded indicates the backend is degraded but operational
|
||||
HealthDegraded
|
||||
|
||||
// HealthUnhealthy indicates the backend is unhealthy
|
||||
HealthUnhealthy
|
||||
)
|
||||
|
||||
// String returns the string representation of the health status
|
||||
func (h HealthStatus) String() string {
|
||||
switch h {
|
||||
case HealthHealthy:
|
||||
return "healthy"
|
||||
case HealthDegraded:
|
||||
return "degraded"
|
||||
case HealthUnhealthy:
|
||||
return "unhealthy"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
return &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
HealthyThreshold: 3,
|
||||
UnhealthyThreshold: 3,
|
||||
DegradedThreshold: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
hc := &HealthChecker{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Start begins health checking
|
||||
func (hc *HealthChecker) Start() {
|
||||
if hc.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
hc.ticker = time.NewTicker(hc.config.CheckInterval)
|
||||
hc.wg.Add(1)
|
||||
go hc.checkLoop()
|
||||
}
|
||||
|
||||
// Stop stops health checking
|
||||
func (hc *HealthChecker) Stop() {
|
||||
if hc.stopped.Swap(true) {
|
||||
return // Already stopped
|
||||
}
|
||||
|
||||
close(hc.stopChan)
|
||||
if hc.ticker != nil {
|
||||
hc.ticker.Stop()
|
||||
}
|
||||
hc.wg.Wait()
|
||||
}
|
||||
|
||||
// checkLoop runs periodic health checks
|
||||
func (hc *HealthChecker) checkLoop() {
|
||||
defer hc.wg.Done()
|
||||
|
||||
// Initial check - log error but continue
|
||||
if err := hc.Check(context.Background()); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
case <-hc.ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
|
||||
if err := hc.Check(ctx); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check performs a health check
|
||||
func (hc *HealthChecker) Check(ctx context.Context) error {
|
||||
if hc.config.CheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hc.totalChecks.Add(1)
|
||||
start := time.Now()
|
||||
|
||||
// Create timeout context if not already set
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Perform health check
|
||||
err := hc.config.CheckFunc(ctx)
|
||||
latency := time.Since(start)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Update average latency
|
||||
hc.updateAverageLatency(latency)
|
||||
|
||||
if err != nil {
|
||||
hc.recordFailure()
|
||||
} else {
|
||||
hc.recordSuccess(latency)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
hc.totalSuccesses.Add(1)
|
||||
successes := hc.consecutiveSuccesses.Add(1)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastSuccessTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
currentStatus := hc.GetStatus()
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
} else {
|
||||
newStatus = HealthHealthy
|
||||
}
|
||||
}
|
||||
|
||||
if newStatus != currentStatus {
|
||||
hc.setStatus(newStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hc *HealthChecker) recordFailure() {
|
||||
hc.totalFailures.Add(1)
|
||||
failures := hc.consecutiveFailures.Add(1)
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastFailureTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
}
|
||||
|
||||
// updateAverageLatency updates the rolling average latency
|
||||
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
|
||||
// Simple exponential moving average
|
||||
currentAvg := time.Duration(hc.averageLatency.Load())
|
||||
if currentAvg == 0 {
|
||||
hc.averageLatency.Store(int64(latency))
|
||||
} else {
|
||||
// Weight: 0.2 for new value, 0.8 for old average
|
||||
newAvg := (currentAvg*4 + latency) / 5
|
||||
hc.averageLatency.Store(int64(newAvg))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the current health status
|
||||
func (hc *HealthChecker) GetStatus() HealthStatus {
|
||||
return HealthStatus(hc.status.Load())
|
||||
}
|
||||
|
||||
// setStatus changes the health status
|
||||
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
|
||||
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
|
||||
|
||||
if oldStatus != newStatus {
|
||||
hc.statusChanges.Add(1)
|
||||
if hc.config.OnStatusChange != nil {
|
||||
hc.config.OnStatusChange(oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy or degraded
|
||||
func (hc *HealthChecker) IsHealthy() bool {
|
||||
status := hc.GetStatus()
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// LastCheckTime returns the time of the last health check
|
||||
func (hc *HealthChecker) LastCheckTime() time.Time {
|
||||
hc.timeMu.RLock()
|
||||
defer hc.timeMu.RUnlock()
|
||||
return hc.lastCheckTime
|
||||
}
|
||||
|
||||
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
|
||||
func (hc *HealthChecker) HealthScore() float64 {
|
||||
status := hc.GetStatus()
|
||||
switch status {
|
||||
case HealthHealthy:
|
||||
return 1.0
|
||||
case HealthDegraded:
|
||||
return 0.7
|
||||
case HealthUnhealthy:
|
||||
return 0.0
|
||||
default:
|
||||
return 0.5
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns health checker statistics
|
||||
func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
hc.timeMu.RLock()
|
||||
lastCheck := hc.lastCheckTime
|
||||
lastSuccess := hc.lastSuccessTime
|
||||
lastFailure := hc.lastFailureTime
|
||||
hc.timeMu.RUnlock()
|
||||
|
||||
totalChecks := hc.totalChecks.Load()
|
||||
totalSuccesses := hc.totalSuccesses.Load()
|
||||
totalFailures := hc.totalFailures.Load()
|
||||
|
||||
successRate := float64(0)
|
||||
if totalChecks > 0 {
|
||||
successRate = float64(totalSuccesses) / float64(totalChecks)
|
||||
}
|
||||
|
||||
return HealthCheckerStats{
|
||||
Status: hc.GetStatus(),
|
||||
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
|
||||
ConsecutiveFailures: hc.consecutiveFailures.Load(),
|
||||
TotalChecks: totalChecks,
|
||||
TotalSuccesses: totalSuccesses,
|
||||
TotalFailures: totalFailures,
|
||||
SuccessRate: successRate,
|
||||
AverageLatency: time.Duration(hc.averageLatency.Load()),
|
||||
StatusChanges: hc.statusChanges.Load(),
|
||||
LastCheckTime: lastCheck,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastFailureTime: lastFailure,
|
||||
HealthScore: hc.HealthScore(),
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
func (hc *HealthChecker) Reset() {
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
hc.totalChecks.Store(0)
|
||||
hc.totalSuccesses.Store(0)
|
||||
hc.totalFailures.Store(0)
|
||||
hc.statusChanges.Store(0)
|
||||
hc.averageLatency.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = now
|
||||
hc.lastSuccessTime = now
|
||||
hc.lastFailureTime = now
|
||||
hc.timeMu.Unlock()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user