mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 413e4a1b7d | |||
| 69e0d98c67 | |||
| 6d893df12b | |||
| 6efb78b7a8 | |||
| d0b920c4f0 | |||
| c474bbafd6 | |||
| 9126c74723 | |||
| a750c4f5b9 | |||
| 56051779ee | |||
| 3f126d50f3 | |||
| 91f0fc9ab8 | |||
| 66b9ed0861 | |||
| e64fc7f730 | |||
| 5fcbd54955 | |||
| e70cd1907c | |||
| e45b06c86d | |||
| ae59a5e88a | |||
| 79e9b164f9 | |||
| 93888e56d1 | |||
| eff9bd7bd2 | |||
| bde1db1c3b |
@@ -0,0 +1,38 @@
|
||||
# Code Owners for traefik-oidc
|
||||
# These owners will be automatically requested for review when someone opens a PR
|
||||
|
||||
# Default owner for everything in the repo
|
||||
* @lukaszraczylo
|
||||
|
||||
# Core authentication and middleware
|
||||
/middleware/ @lukaszraczylo
|
||||
/auth/ @lukaszraczylo
|
||||
/handlers/ @lukaszraczylo
|
||||
|
||||
# OIDC providers
|
||||
/internal/providers/ @lukaszraczylo
|
||||
|
||||
# Session management and security
|
||||
/session/ @lukaszraczylo
|
||||
/internal/security/ @lukaszraczylo
|
||||
/security/ @lukaszraczylo
|
||||
|
||||
# Token management
|
||||
/internal/token/ @lukaszraczylo
|
||||
|
||||
# Configuration
|
||||
/config/ @lukaszraczylo
|
||||
/.traefik.yml @lukaszraczylo
|
||||
|
||||
# GitHub Actions and CI/CD
|
||||
/.github/ @lukaszraczylo
|
||||
/.github/workflows/ @lukaszraczylo
|
||||
/.golangci.yml @lukaszraczylo
|
||||
|
||||
# Documentation
|
||||
/docs/ @lukaszraczylo
|
||||
README.md @lukaszraczylo
|
||||
|
||||
# Dependencies
|
||||
go.mod @lukaszraczylo
|
||||
go.sum @lukaszraczylo
|
||||
@@ -0,0 +1,123 @@
|
||||
## Description
|
||||
|
||||
<!-- Provide a brief description of the changes in this PR -->
|
||||
|
||||
## Type of Change
|
||||
|
||||
<!-- Mark the relevant option with an "x" -->
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] Documentation update
|
||||
- [ ] Performance improvement
|
||||
- [ ] Code refactoring
|
||||
- [ ] Security fix
|
||||
- [ ] Provider-specific fix/enhancement
|
||||
|
||||
## Related Issues
|
||||
|
||||
<!-- Link to related issues using #issue_number -->
|
||||
|
||||
Fixes #
|
||||
Related to #
|
||||
|
||||
## Changes Made
|
||||
|
||||
<!-- List the main changes made in this PR -->
|
||||
|
||||
-
|
||||
-
|
||||
-
|
||||
|
||||
## Provider Impact
|
||||
|
||||
<!-- If this affects specific OIDC providers, list them here -->
|
||||
|
||||
- [ ] Google
|
||||
- [ ] Azure AD
|
||||
- [ ] Auth0
|
||||
- [ ] Okta
|
||||
- [ ] Keycloak
|
||||
- [ ] AWS Cognito
|
||||
- [ ] GitLab
|
||||
- [ ] GitHub
|
||||
- [ ] Generic OIDC
|
||||
- [ ] All providers
|
||||
|
||||
## Testing Performed
|
||||
|
||||
<!-- Describe the tests you ran to verify your changes -->
|
||||
|
||||
- [ ] Unit tests pass locally
|
||||
- [ ] Integration tests pass locally
|
||||
- [ ] Race detector shows no issues
|
||||
- [ ] Memory leak tests pass
|
||||
- [ ] Manual testing performed
|
||||
|
||||
### Test Configuration
|
||||
|
||||
<!-- Provide details about your test configuration if applicable -->
|
||||
|
||||
**Provider tested:**
|
||||
**Go version:**
|
||||
**Traefik version:**
|
||||
|
||||
## Security Considerations
|
||||
|
||||
<!-- Describe any security implications of these changes -->
|
||||
|
||||
- [ ] This PR does not introduce security vulnerabilities
|
||||
- [ ] Security scanning has been performed
|
||||
- [ ] Credentials/secrets are properly handled
|
||||
- [ ] Input validation is implemented
|
||||
|
||||
## Performance Impact
|
||||
|
||||
<!-- Describe any performance implications -->
|
||||
|
||||
- [ ] No performance impact expected
|
||||
- [ ] Performance improved (describe how)
|
||||
- [ ] Performance may be affected (describe why and mitigation)
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
<!-- If this is a breaking change, describe the impact and migration path -->
|
||||
|
||||
**Breaking changes:**
|
||||
|
||||
|
||||
**Migration guide:**
|
||||
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Ensure all items are checked before requesting review -->
|
||||
|
||||
- [ ] My code follows the project's code style
|
||||
- [ ] I have performed a self-review of my code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] I have made corresponding changes to the documentation
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] New and existing unit tests pass locally with my changes
|
||||
- [ ] Any dependent changes have been merged and published
|
||||
|
||||
## Additional Context
|
||||
|
||||
<!-- Add any other context, screenshots, or information about the PR here -->
|
||||
|
||||
## Screenshots (if applicable)
|
||||
|
||||
<!-- Add screenshots to help explain your changes -->
|
||||
|
||||
---
|
||||
|
||||
**For Reviewers:**
|
||||
|
||||
Please verify:
|
||||
- [ ] Code quality and style
|
||||
- [ ] Test coverage is adequate
|
||||
- [ ] Security implications reviewed
|
||||
- [ ] Documentation is updated
|
||||
- [ ] No performance regressions
|
||||
@@ -0,0 +1,52 @@
|
||||
version: 2
|
||||
updates:
|
||||
# Maintain dependencies for GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 5
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "github-actions"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
|
||||
# Maintain Go module dependencies
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 10
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "go"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
# Group patch updates together
|
||||
groups:
|
||||
patch-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "minor"
|
||||
# Ignore certain dependencies if needed
|
||||
ignore:
|
||||
# Example: ignore specific versions
|
||||
# - dependency-name: "github.com/example/package"
|
||||
# versions: ["1.x", "2.x"]
|
||||
@@ -0,0 +1,9 @@
|
||||
# Ensure consistent line endings
|
||||
* text=auto eol=lf
|
||||
|
||||
# GitHub Actions files should use LF
|
||||
*.yml text eol=lf
|
||||
*.yaml text eol=lf
|
||||
|
||||
# Shell scripts should use LF
|
||||
*.sh text eol=lf
|
||||
@@ -0,0 +1,225 @@
|
||||
# GitHub Actions Workflows
|
||||
|
||||
This directory contains CI/CD workflows for the Traefik OIDC middleware.
|
||||
|
||||
## Workflows
|
||||
|
||||
### PR Validation (`pr-validation.yml`)
|
||||
|
||||
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
|
||||
|
||||
**Triggered on:**
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
**Parallel Jobs (20+ concurrent checks):**
|
||||
|
||||
#### Code Quality
|
||||
- **Quick Checks** - Format, go vet, go mod verify
|
||||
- **golangci-lint** - Comprehensive linting
|
||||
- **Staticcheck** - Static analysis
|
||||
|
||||
#### Security
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database check
|
||||
- **CodeQL** - GitHub's code analysis
|
||||
|
||||
#### Testing
|
||||
- **Race Detector** - Concurrent access bug detection
|
||||
- **Coverage** - Test coverage with 75% threshold
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration test suite
|
||||
- **Regression Tests** - Prevent previously fixed bugs
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management validation
|
||||
- **Token Tests** - Token validation scenarios
|
||||
- **CSRF Tests** - CSRF protection validation
|
||||
|
||||
#### Provider Testing (Matrix)
|
||||
Tests run in parallel for each OIDC provider:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Compatibility
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Build Matrix** - linux/darwin × amd64/arm64
|
||||
- **Go Versions** - Go 1.23 and 1.24 compatibility
|
||||
|
||||
#### Final Validation
|
||||
- **All Checks Passed** - Ensures all jobs succeeded
|
||||
|
||||
## Workflow Features
|
||||
|
||||
### 🚀 Parallel Execution
|
||||
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
|
||||
|
||||
### 📊 Coverage Reporting
|
||||
- Automatic PR comments with coverage statistics
|
||||
- Per-package coverage breakdown
|
||||
- 75% coverage threshold enforcement
|
||||
|
||||
### 🔒 Security First
|
||||
- Multiple security scanners (gosec, govulncheck, CodeQL)
|
||||
- SARIF report uploads for GitHub Security tab
|
||||
- Security edge case testing
|
||||
|
||||
### 🎯 Comprehensive Testing
|
||||
- Race condition detection
|
||||
- Memory leak detection
|
||||
- Provider-specific testing
|
||||
- Integration and regression tests
|
||||
|
||||
### 📈 Performance Tracking
|
||||
- Benchmark results stored as artifacts
|
||||
- Performance regression detection
|
||||
|
||||
### ✅ Quality Gates
|
||||
All checks must pass before PR can be merged:
|
||||
- Code formatting and style
|
||||
- Security vulnerabilities
|
||||
- Test coverage threshold
|
||||
- Race conditions
|
||||
- Memory leaks
|
||||
- Build success on all platforms
|
||||
|
||||
## Local Development
|
||||
|
||||
### Run checks locally before pushing:
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Run tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Check coverage
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# Run specific test suites
|
||||
go test -v -run='.*Leak.*' ./... # Memory leak tests
|
||||
go test -v -run='.*Integration.*' ./... # Integration tests
|
||||
go test -v -run='.*Regression.*' ./... # Regression tests
|
||||
|
||||
# Run benchmarks
|
||||
go test -bench=. -benchmem ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
govulncheck ./...
|
||||
```
|
||||
|
||||
### Required Tools
|
||||
|
||||
Install these tools for local development:
|
||||
|
||||
```bash
|
||||
# golangci-lint
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Workflow Fails
|
||||
|
||||
1. **Check job status** - Click on failed job for details
|
||||
2. **Review logs** - Expand failed steps to see error messages
|
||||
3. **Run locally** - Reproduce issue with local commands above
|
||||
4. **Check coverage** - Ensure test coverage meets 75% threshold
|
||||
|
||||
### Coverage Below Threshold
|
||||
|
||||
Add tests to increase coverage:
|
||||
```bash
|
||||
# See which lines aren't covered
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Detected
|
||||
|
||||
Run with race detector locally:
|
||||
```bash
|
||||
go test -race -v ./...
|
||||
```
|
||||
|
||||
### Provider Test Failure
|
||||
|
||||
Test specific provider:
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./internal/providers/...
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
The workflow is optimized for speed:
|
||||
|
||||
- **Parallel execution** - All independent jobs run simultaneously
|
||||
- **Go caching** - Dependencies cached between runs
|
||||
- **Strategic ordering** - Quick checks run first for fast feedback
|
||||
- **Fail-fast disabled** - Continue running all tests even if some fail
|
||||
|
||||
## Workflow Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
|
||||
|
||||
### Status Badges
|
||||
Add to README.md:
|
||||
```markdown
|
||||

|
||||
```
|
||||
|
||||
### Notifications
|
||||
Configure in repository settings:
|
||||
- Settings → Notifications
|
||||
- Choose email or Slack notifications for workflow failures
|
||||
|
||||
## Maintenance
|
||||
|
||||
### Update Go Version
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
go-version: '1.24' # Update this
|
||||
```
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
THRESHOLD=75 # Adjust this value
|
||||
```
|
||||
|
||||
### Add New Provider
|
||||
Add to provider matrix:
|
||||
```yaml
|
||||
matrix:
|
||||
provider:
|
||||
- new_provider # Add here
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
- [golangci-lint Configuration](../.golangci.yml)
|
||||
- [Dependabot Configuration](../dependabot.yml)
|
||||
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- "**"
|
||||
- "!main"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.go"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
+2
-1
@@ -1,2 +1,3 @@
|
||||
docker/
|
||||
.claude/
|
||||
.claude/*.out
|
||||
*.test
|
||||
|
||||
+192
@@ -0,0 +1,192 @@
|
||||
version: "2"
|
||||
run:
|
||||
go: "1.24"
|
||||
modules-download-mode: readonly
|
||||
tests: true
|
||||
linters:
|
||||
enable:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- misspell
|
||||
- noctx
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
disable:
|
||||
- exhaustive
|
||||
- funlen
|
||||
- gocognit
|
||||
- lll
|
||||
- mnd
|
||||
- testpackage
|
||||
- wsl
|
||||
settings:
|
||||
dupl:
|
||||
threshold: 200 # Allow intentional duplication in provider patterns and token management
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
|
||||
exclude-functions:
|
||||
- (io.Closer).Close
|
||||
- (*database/sql.Rows).Close
|
||||
- (*database/sql.Stmt).Close
|
||||
- (io.Writer).Write
|
||||
- (*net/http.ResponseWriter).Write
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintln
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
|
||||
ignore-tests: true
|
||||
gocritic:
|
||||
# Using default enabled checks in v2
|
||||
enabled-checks:
|
||||
- appendCombine
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- commentedOutCode
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- hexLiteral
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nestingReduce
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- stringXbytes
|
||||
- typeAssertChain
|
||||
- typeUnparen
|
||||
- unlabelStmt
|
||||
- yodaStyleExpr
|
||||
gocyclo:
|
||||
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
|
||||
gosec:
|
||||
excludes:
|
||||
- G104
|
||||
- G404
|
||||
severity: medium
|
||||
confidence: medium
|
||||
govet:
|
||||
disable:
|
||||
- fieldalignment
|
||||
- shadow
|
||||
enable-all: true
|
||||
misspell:
|
||||
locale: US
|
||||
ignore-rules:
|
||||
- traefik
|
||||
- oidc
|
||||
- keycloak
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
allow-unused: false
|
||||
prealloc:
|
||||
simple: true
|
||||
range-loops: true
|
||||
for-loops: false
|
||||
revive:
|
||||
rules:
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: context-keys-type
|
||||
- name: dot-imports
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
- name: if-return
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
unparam:
|
||||
check-exported: false
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
|
||||
- -QF1003 # Tagged switch - style preference, may affect Yaegi
|
||||
- -QF1007 # Merge conditional assignment - style preference
|
||||
- -QF1008 # Remove embedded field - may break Yaegi compatibility
|
||||
- -QF1012 # Use fmt.Fprintf - style preference
|
||||
- -ST1003 # Package name format - allowed for test packages
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
- linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- errcheck
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gosec
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
- gocyclo
|
||||
path: test.*\.go
|
||||
- linters:
|
||||
- gocritic
|
||||
- unused
|
||||
path: mocks.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
text: 'G404:'
|
||||
- linters:
|
||||
- all
|
||||
path: vendor/
|
||||
- linters:
|
||||
- goconst
|
||||
path: (.+)_test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session_chunk_manager\.go
|
||||
text: "(extractJWTExpiration|extractJWTIssuedAt)"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
uniq-by-line: true
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
@@ -0,0 +1,60 @@
|
||||
version: 2
|
||||
|
||||
# Traefik plugins are source-only - no binary builds
|
||||
# Traefik loads plugins via Yaegi interpreter at runtime
|
||||
builds:
|
||||
- skip: true
|
||||
|
||||
# Create source archive for GitHub releases
|
||||
archives:
|
||||
- formats: [tar.gz]
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
|
||||
files:
|
||||
- "*.go"
|
||||
- "**/*.go"
|
||||
- go.mod
|
||||
- go.sum
|
||||
- .traefik.yml
|
||||
- LICENSE*
|
||||
- README*
|
||||
# Exclude test files and vendor from release archive
|
||||
- "!**/*_test.go"
|
||||
- "!vendor/**"
|
||||
- "!docker/**"
|
||||
- "!integration/**"
|
||||
- "!regression/**"
|
||||
- "!examples/**"
|
||||
- "!docs/**"
|
||||
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
|
||||
algorithm: sha256
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
- "^Merge"
|
||||
- "^WIP"
|
||||
- "^chore:"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: lukaszraczylo
|
||||
name: traefikoidc
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sigstore.json"
|
||||
args:
|
||||
- sign-blob
|
||||
- "--bundle=${signature}"
|
||||
- "${artifact}"
|
||||
- "--yes"
|
||||
artifacts: checksum
|
||||
output: true
|
||||
+774
-43
@@ -31,6 +31,7 @@ summary: |
|
||||
- Flexible configuration with multiple deployment scenarios
|
||||
- Memory-efficient operation with automatic cleanup
|
||||
- Extensive logging and debugging capabilities
|
||||
- Redis cache support for multi-replica deployments with automatic failover
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
@@ -73,7 +74,16 @@ testData:
|
||||
- admin
|
||||
- developer
|
||||
|
||||
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
# When running behind load balancer that terminates TLS: MUST set to TRUE
|
||||
# See: https://github.com/lukaszraczylo/traefikoidc/issues/82
|
||||
forceHTTPS: true # Forces HTTPS scheme for redirect URIs (default when not specified: false)
|
||||
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
|
||||
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
|
||||
|
||||
@@ -84,25 +94,36 @@ testData:
|
||||
- /metrics
|
||||
|
||||
headers: # Custom headers to set with templated values from claims and tokens
|
||||
# NOTE: If you encounter "can't evaluate field AccessToken in type bool" errors,
|
||||
# you may need to escape the templates. See the headers section in configuration below.
|
||||
# NOTE: Use double curly braces to escape template expressions in YAML
|
||||
# See the headers section in configuration below for details
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
value: "{{{{.Claims.email}}}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
value: "{{{{.Claims.sub}}}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
cookieDomain: "" # Explicit domain for session cookies (e.g., ".example.com" for multi-subdomain setups)
|
||||
cookiePrefix: "" # Custom prefix for cookie names (e.g., "_oidc_myapp_" for session isolation between middleware instances)
|
||||
sessionMaxAge: 86400 # Maximum session age in seconds (default: 86400 = 24 hours, 0 = use default)
|
||||
overrideScopes: false # When true, replaces default scopes instead of appending (default: false)
|
||||
refreshGracePeriodSeconds: 60 # Seconds before token expiry to attempt proactive refresh (default: 60)
|
||||
|
||||
|
||||
# Auth0 / Custom API Audience Configuration
|
||||
audience: "" # Custom audience for access token validation (default: clientID)
|
||||
strictAudienceValidation: false # Reject sessions with audience mismatch (prevents token confusion attacks)
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
@@ -115,11 +136,53 @@ testData:
|
||||
- "https://*.example.com"
|
||||
corsAllowCredentials: true
|
||||
|
||||
# Cross-origin policies
|
||||
permissionsPolicy: "geolocation=(), camera=(), microphone=()"
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "production"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Example with Redis cache for multi-replica deployments
|
||||
testDataWithRedis:
|
||||
# Required OIDC parameters (same as standard configuration)
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
sessionEncryptionKey: your-64-character-encryption-key-at-least-32-bytes
|
||||
|
||||
# Standard optional parameters
|
||||
logLevel: info
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
|
||||
# Redis cache configuration for multi-replica support
|
||||
redis:
|
||||
enabled: true # Enable Redis caching
|
||||
address: "redis:6379" # Redis server address
|
||||
password: "redis-password" # Redis authentication password
|
||||
db: 0 # Redis database number (0-15)
|
||||
keyPrefix: "traefikoidc:" # Prefix for all Redis keys
|
||||
cacheMode: "hybrid" # Cache mode: redis, hybrid, or memory
|
||||
poolSize: 20 # Maximum number of connections
|
||||
connectTimeout: 5 # Connection timeout in seconds
|
||||
readTimeout: 3 # Read operation timeout
|
||||
writeTimeout: 3 # Write operation timeout
|
||||
enableTLS: false # Use TLS for Redis connection
|
||||
tlsSkipVerify: false # Skip TLS certificate verification
|
||||
hybridL1Size: 500 # L1 cache size for hybrid mode
|
||||
hybridL1MemoryMB: 10 # L1 memory limit for hybrid mode
|
||||
enableCircuitBreaker: true # Enable circuit breaker
|
||||
circuitBreakerThreshold: 5 # Failures before opening circuit
|
||||
circuitBreakerTimeout: 60 # Timeout before retry (seconds)
|
||||
enableHealthCheck: true # Enable periodic health checks
|
||||
healthCheckInterval: 30 # Health check interval (seconds)
|
||||
|
||||
# --- Common Configuration Examples ---
|
||||
#
|
||||
# 🔒 HIGH-SECURITY CONFIGURATION
|
||||
@@ -169,11 +232,11 @@ testData:
|
||||
# corsAllowedOrigins: ["https://app.example.com"]
|
||||
# corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
# corsAllowedHeaders: ["Authorization", "Content-Type", "X-API-Key"]
|
||||
# headers: # Custom headers with OIDC claims
|
||||
# headers: # Custom headers with OIDC claims (use double curly braces)
|
||||
# - name: "X-User-Email"
|
||||
# value: "{{.Claims.email}}"
|
||||
# value: "{{{{.Claims.email}}}}"
|
||||
# - name: "X-User-ID"
|
||||
# value: "{{.Claims.sub}}"
|
||||
# value: "{{{{.Claims.sub}}}}"
|
||||
|
||||
# --- Provider Specific Configuration Examples ---
|
||||
#
|
||||
@@ -206,6 +269,8 @@ testData:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -227,6 +292,26 @@ testData:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -312,6 +397,20 @@ testData:
|
||||
# clientSecret: your-auth0-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-auth0"
|
||||
#
|
||||
# # Auth0 Audience Configuration (for custom APIs)
|
||||
# # Scenario 1 (Recommended): Custom API with JWT access tokens
|
||||
# audience: "https://my-api.example.com" # Your API identifier from Auth0 dashboard
|
||||
# strictAudienceValidation: true # Enforce proper audience validation for security
|
||||
#
|
||||
# # Scenario 2 (Backward Compatible): Default audience (uses client_id)
|
||||
# # audience: "" # Leave empty or omit to use client_id as audience
|
||||
# # strictAudienceValidation: false # Allows fallback to ID token validation (logs warnings)
|
||||
#
|
||||
# # Scenario 3: Opaque (non-JWT) access tokens
|
||||
# # allowOpaqueTokens: true # Enable opaque token support
|
||||
# # requireTokenIntrospection: true # Require RFC 7662 token introspection
|
||||
#
|
||||
# scopes: # Defaults ["openid", "profile", "email"]. Add custom scopes if your Auth0 Rules/Actions require them.
|
||||
# - read:custom_data # Example custom scope
|
||||
# allowedRolesAndGroups: # Based on claims added via Auth0 Rules or Actions (e.g. namespaced claims)
|
||||
@@ -319,7 +418,7 @@ testData:
|
||||
# - editor
|
||||
# # Use Auth0 Rules or Actions to add custom claims (roles, permissions) to the ID Token.
|
||||
# # Ensure postLogoutRedirectURI is in Auth0 app's "Allowed Logout URLs".
|
||||
# # See README.md "Provider Configuration Recommendations" for Auth0.
|
||||
# # For detailed Auth0 audience configuration, see AUTH0_AUDIENCE_GUIDE.md
|
||||
|
||||
# --- Generic OIDC Provider Example ---
|
||||
# testDataGenericOIDC:
|
||||
@@ -448,9 +547,24 @@ configuration:
|
||||
forceHTTPS:
|
||||
type: boolean
|
||||
description: |
|
||||
Forces the use of HTTPS for all URLs.
|
||||
This is recommended for security in production environments.
|
||||
Default: true
|
||||
Forces HTTPS scheme for redirect URIs regardless of request headers or TLS state.
|
||||
|
||||
⚠️ CRITICAL CONFIGURATION for TLS Termination Scenarios:
|
||||
|
||||
When running Traefik behind a load balancer that terminates TLS (AWS ALB,
|
||||
Google Cloud Load Balancer, Azure Application Gateway, etc.), you MUST set
|
||||
this to true. Without it, redirect URIs will use http:// instead of https://,
|
||||
causing OAuth callback failures.
|
||||
|
||||
How it works:
|
||||
- When true: Always uses https:// for redirect URIs (highest priority)
|
||||
- When false: Detects scheme from X-Forwarded-Proto header or TLS state
|
||||
- When NOT specified: Defaults to false (Go zero value for bool)
|
||||
|
||||
Default: false (when not specified in configuration)
|
||||
Recommended: true (for production environments and TLS termination scenarios)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/82
|
||||
required: false
|
||||
|
||||
rateLimit:
|
||||
@@ -516,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -549,28 +695,101 @@ configuration:
|
||||
cookieDomain:
|
||||
type: string
|
||||
description: |
|
||||
Explicit domain for session cookies. This is important for multi-subdomain setups
|
||||
Explicit domain for session cookies. This is important for multi-subdomain setups
|
||||
and reverse proxy deployments to ensure consistent cookie handling.
|
||||
|
||||
|
||||
When set, all session cookies will use this domain. When not set, the domain
|
||||
is auto-detected from the request headers (X-Forwarded-Host or Host).
|
||||
|
||||
|
||||
Use a leading dot for subdomain-wide cookies (e.g., ".example.com" allows
|
||||
cookies to be shared between app.example.com, api.example.com, etc.).
|
||||
|
||||
|
||||
Use a specific domain for host-only cookies (e.g., "app.example.com" restricts
|
||||
cookies to that exact domain).
|
||||
|
||||
|
||||
This setting is crucial to prevent authentication issues like "CSRF token missing
|
||||
in session" errors that can occur when cookies are created with inconsistent domains.
|
||||
|
||||
|
||||
Examples:
|
||||
- ".example.com" - Allows all subdomains to share cookies
|
||||
- "app.example.com" - Restricts cookies to this specific host
|
||||
|
||||
|
||||
Default: "" (auto-detected from request headers)
|
||||
required: false
|
||||
|
||||
cookiePrefix:
|
||||
type: string
|
||||
description: |
|
||||
Custom prefix for session cookie names. This is essential for running multiple
|
||||
middleware instances with different authorization requirements on the same domain.
|
||||
|
||||
By default, all middleware instances use the same cookie names (_oidc_raczylo_m,
|
||||
_oidc_raczylo_a, etc.), which means they share session state. When you have
|
||||
multiple instances with different access restrictions (e.g., one for general users
|
||||
and one for admins), this session sharing can lead to authorization bypass issues.
|
||||
|
||||
Setting a unique cookiePrefix for each middleware instance ensures complete
|
||||
session isolation, preventing users authenticated via one middleware from
|
||||
automatically gaining access to routes protected by a different middleware.
|
||||
|
||||
The prefix is prepended to all session cookie names:
|
||||
- Main session cookie: {prefix}m
|
||||
- Access token cookie: {prefix}a
|
||||
- Refresh token cookie: {prefix}r
|
||||
- ID token cookie: {prefix}id
|
||||
|
||||
Examples:
|
||||
- "_oidc_userauth_" - For general user authentication middleware
|
||||
- "_oidc_adminauth_" - For admin-only authentication middleware
|
||||
- "_oidc_api_" - For API-specific authentication middleware
|
||||
|
||||
Security Note: Use different cookie prefixes AND different sessionEncryptionKey
|
||||
values for each middleware instance to ensure complete isolation.
|
||||
|
||||
Default: "_oidc_raczylo_" (standard prefix for backward compatibility)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/87
|
||||
required: false
|
||||
|
||||
sessionMaxAge:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum session age in seconds before requiring re-authentication.
|
||||
|
||||
This setting controls how long a user's authentication session remains valid
|
||||
before they must authenticate again through the OIDC provider. The session
|
||||
age is tracked from the initial authentication time (created_at).
|
||||
|
||||
When a session exceeds this age:
|
||||
- The session is cleared and invalidated
|
||||
- The user is redirected to re-authenticate
|
||||
- All session cookies are removed
|
||||
|
||||
Use Cases:
|
||||
- High-security applications: Use shorter durations (e.g., 3600 = 1 hour)
|
||||
- Standard applications: Default 24 hours balances security and UX
|
||||
- Long-lived sessions: Extend for applications accessed infrequently
|
||||
(e.g., 604800 = 7 days, 2592000 = 30 days)
|
||||
|
||||
Security Considerations:
|
||||
- Shorter sessions provide better security but require more frequent logins
|
||||
- Longer sessions improve user experience but increase security risk
|
||||
- Consider your application's security requirements and user access patterns
|
||||
- This is independent of token refresh - tokens can be refreshed during the session
|
||||
|
||||
Common Values:
|
||||
- 3600 (1 hour) - High security applications
|
||||
- 28800 (8 hours) - Working day session
|
||||
- 86400 (24 hours) - Default, balances security and convenience
|
||||
- 604800 (7 days) - Weekly session for less frequently accessed apps
|
||||
- 2592000 (30 days) - Monthly session for infrequently used applications
|
||||
|
||||
Default: 86400 (24 hours)
|
||||
Minimum: 0 (uses default of 24 hours)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/91
|
||||
required: false
|
||||
|
||||
overrideScopes:
|
||||
type: boolean
|
||||
description: |
|
||||
@@ -588,16 +807,220 @@ configuration:
|
||||
type: integer
|
||||
description: |
|
||||
The number of seconds before a token expires to attempt proactive refresh.
|
||||
|
||||
|
||||
When a request is made and the access token will expire within this grace period,
|
||||
the middleware will attempt to refresh the token proactively. This helps prevent
|
||||
authentication interruptions for active users.
|
||||
|
||||
|
||||
Setting this to 0 disables proactive refresh (tokens are only refreshed after expiry).
|
||||
|
||||
|
||||
Default: 60 (1 minute before expiry)
|
||||
required: false
|
||||
|
||||
audience:
|
||||
type: string
|
||||
description: |
|
||||
Custom audience value for access token validation.
|
||||
|
||||
This configures the expected audience claim in access tokens. Per OAuth 2.0 and OIDC
|
||||
specifications:
|
||||
- ID tokens always have aud=client_id (per OIDC Core 1.0)
|
||||
- Access tokens can have custom audiences (e.g., API identifiers)
|
||||
|
||||
Auth0 Scenarios:
|
||||
1. Custom API audience (recommended): Set to your API identifier from Auth0
|
||||
Example: "https://my-api.example.com"
|
||||
Result: Access tokens will contain this audience
|
||||
|
||||
2. Default audience: Leave empty or omit (uses client_id)
|
||||
Result: Access tokens may not contain client_id, triggering warnings
|
||||
|
||||
3. Opaque tokens: Use with allowOpaqueTokens=true for non-JWT tokens
|
||||
|
||||
When configured and different from client_id, the middleware automatically adds
|
||||
the audience parameter to the authorize endpoint request.
|
||||
|
||||
Default: "" (uses client_id as audience)
|
||||
See: AUTH0_AUDIENCE_GUIDE.md for detailed configuration
|
||||
required: false
|
||||
|
||||
strictAudienceValidation:
|
||||
type: boolean
|
||||
description: |
|
||||
Enforce strict audience validation for access tokens.
|
||||
|
||||
When enabled, sessions are rejected if access token validation fails due to
|
||||
audience mismatch. This prevents falling back to ID token validation, addressing
|
||||
potential token confusion attacks where tokens intended for different APIs could
|
||||
be used to grant access.
|
||||
|
||||
Auth0 Scenario 2 Protection:
|
||||
- When true: Rejects sessions with mismatched access token audience
|
||||
- When false: Logs security warnings but allows fallback to ID token (backward compatible)
|
||||
|
||||
Security Recommendation:
|
||||
- Production environments: Set to true for maximum security
|
||||
- Development/testing: Can use false with monitoring of security warnings
|
||||
|
||||
This setting addresses security concerns where access tokens without proper
|
||||
audience claims could bypass API-specific authorization checks.
|
||||
|
||||
Default: false (backward compatible)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/74
|
||||
required: false
|
||||
|
||||
allowOpaqueTokens:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable acceptance of opaque (non-JWT) access tokens.
|
||||
|
||||
When enabled, the middleware accepts access tokens that are not in JWT format
|
||||
(3-part base64 structure). Opaque tokens are validated using OAuth 2.0 Token
|
||||
Introspection (RFC 7662) if the provider exposes an introspection endpoint.
|
||||
|
||||
Auth0 Scenario 3:
|
||||
Some Auth0 configurations issue opaque access tokens when no default API is
|
||||
configured. This setting allows those tokens to be validated.
|
||||
|
||||
Requirements:
|
||||
- Provider must support introspection_endpoint in OIDC discovery
|
||||
- Client must have appropriate introspection permissions
|
||||
|
||||
Validation Process:
|
||||
1. Detects opaque token (not 3-part JWT structure)
|
||||
2. Calls provider's introspection endpoint with client credentials
|
||||
3. Validates response (active status, expiration, audience if present)
|
||||
4. Caches result for 5 minutes or token expiry (whichever is shorter)
|
||||
5. Falls back to ID token validation if introspection unavailable
|
||||
(unless requireTokenIntrospection=true)
|
||||
|
||||
Default: false (only JWT access tokens accepted)
|
||||
See: AUTH0_AUDIENCE_GUIDE.md for configuration examples
|
||||
required: false
|
||||
|
||||
requireTokenIntrospection:
|
||||
type: boolean
|
||||
description: |
|
||||
Require token introspection for all opaque access tokens.
|
||||
|
||||
When enabled with allowOpaqueTokens=true, opaque tokens are rejected if:
|
||||
- Introspection endpoint is not available from provider metadata
|
||||
- Introspection request fails
|
||||
- Introspection response indicates token is not active
|
||||
|
||||
Security Levels:
|
||||
- requireTokenIntrospection=true + allowOpaqueTokens=true:
|
||||
Maximum security - rejects opaque tokens without successful introspection
|
||||
|
||||
- requireTokenIntrospection=false + allowOpaqueTokens=true:
|
||||
Backward compatible - falls back to ID token validation if introspection fails
|
||||
|
||||
- requireTokenIntrospection=true + allowOpaqueTokens=false:
|
||||
No effect - opaque tokens are already rejected
|
||||
|
||||
Recommended Configuration:
|
||||
When accepting opaque tokens, always set this to true for maximum security:
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
Default: false (allows fallback to ID token)
|
||||
See: RFC 7662 OAuth 2.0 Token Introspection specification
|
||||
required: false
|
||||
|
||||
disableReplayDetection:
|
||||
type: boolean
|
||||
description: |
|
||||
Disable JTI-based replay attack detection for multi-replica deployments.
|
||||
|
||||
When running multiple Traefik replicas, each instance maintains its own in-memory
|
||||
JTI (JWT Token ID) cache. This causes false positives when the same valid token
|
||||
hits different replicas:
|
||||
- Request → Replica A → JTI added to cache → OK
|
||||
- Request → Replica B → JTI not in Replica B's cache → OK
|
||||
- Request → Replica A again → JTI found → FALSE POSITIVE "replay detected"
|
||||
|
||||
Security Impact:
|
||||
When disabled, the following validations remain active:
|
||||
- RSA/ECDSA signature verification
|
||||
- Token expiration (exp claim)
|
||||
- Issuer validation (iss claim)
|
||||
- Audience validation (aud claim)
|
||||
- Not-before validation (nbf claim)
|
||||
- Issued-at validation (iat claim)
|
||||
|
||||
Only the JTI replay check is skipped.
|
||||
|
||||
Recommendations:
|
||||
- Single-instance deployment: false (default, enables replay protection)
|
||||
- Multi-replica deployment: true (prevents false positives)
|
||||
- Production with shared cache: false (use Redis/Memcached for shared JTI cache)
|
||||
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
@@ -614,29 +1037,23 @@ configuration:
|
||||
IMPORTANT: Template Escaping
|
||||
If you encounter the error "can't evaluate field AccessToken in type bool" when
|
||||
starting Traefik, this means Traefik is trying to evaluate the template expressions
|
||||
before passing them to the plugin. To fix this, you need to escape the templates
|
||||
using one of these methods:
|
||||
before passing them to the plugin.
|
||||
|
||||
1. Use YAML literal style (recommended):
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: |
|
||||
Bearer {{.AccessToken}}
|
||||
SOLUTION: You must escape the template expressions using double curly braces:
|
||||
|
||||
2. Use single quotes:
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: 'Bearer {{.AccessToken}}'
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
|
||||
3. For inline double quotes, escape the braces:
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{"{{.AccessToken}}"}}"
|
||||
This is the only reliable method that works consistently. Here's why:
|
||||
- The YAML parser converts {{{{ → {{ and }}}} → }}
|
||||
- Result: Bearer {{.AccessToken}} reaches the Go template engine correctly
|
||||
- Other methods (YAML literal style, single quotes) do NOT work reliably
|
||||
|
||||
Examples:
|
||||
- name: "X-User-Email", value: "{{.Claims.email}}"
|
||||
- name: "Authorization", value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
- name: "X-User-Email", value: "{{{{.Claims.email}}}}"
|
||||
- name: "Authorization", value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles", value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
required: false
|
||||
items:
|
||||
type: object
|
||||
@@ -899,3 +1316,317 @@ configuration:
|
||||
Remove the X-Powered-By header to hide technology stack information.
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
permissionsPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Permissions-Policy header to control browser feature permissions.
|
||||
This header allows you to control which features and APIs can be used.
|
||||
|
||||
Examples:
|
||||
- "geolocation=(), camera=(), microphone=()" (deny all)
|
||||
- "geolocation=(self), camera=()" (allow geolocation for same origin only)
|
||||
|
||||
Common directives: accelerometer, camera, geolocation, gyroscope,
|
||||
magnetometer, microphone, payment, usb
|
||||
required: false
|
||||
|
||||
crossOriginEmbedderPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Embedder-Policy (COEP) header to prevent untrusted
|
||||
resources from being loaded.
|
||||
|
||||
Options:
|
||||
- "require-corp": Resources must explicitly grant permission
|
||||
- "credentialless": Load without credentials for cross-origin resources
|
||||
- "unsafe-none": No restrictions (default)
|
||||
|
||||
Required for certain browser features like SharedArrayBuffer.
|
||||
required: false
|
||||
|
||||
crossOriginOpenerPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Opener-Policy (COOP) header to isolate browsing context
|
||||
from cross-origin windows.
|
||||
|
||||
Options:
|
||||
- "same-origin": Isolate from cross-origin documents
|
||||
- "same-origin-allow-popups": Allow popups that don't set COOP
|
||||
- "unsafe-none": No isolation (default)
|
||||
|
||||
Helps prevent cross-origin attacks and Spectre-like vulnerabilities.
|
||||
required: false
|
||||
|
||||
crossOriginResourcePolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Resource-Policy (CORP) header to control which origins
|
||||
can load this resource.
|
||||
|
||||
Options:
|
||||
- "same-origin": Only same-origin requests can load the resource
|
||||
- "same-site": Only same-site requests can load the resource
|
||||
- "cross-origin": Any origin can load the resource (default)
|
||||
|
||||
Prevents your resources from being embedded on other sites.
|
||||
required: false
|
||||
|
||||
redis:
|
||||
type: object
|
||||
description: |
|
||||
Optional Redis cache configuration for multi-replica deployments.
|
||||
|
||||
When running multiple Traefik instances, Redis provides shared caching to:
|
||||
- Prevent JTI replay detection false positives across replicas
|
||||
- Share token verification results between instances
|
||||
- Maintain consistent session state across the cluster
|
||||
- Improve performance by reducing redundant OIDC provider calls
|
||||
|
||||
Features:
|
||||
- Automatic failover to memory-only mode when Redis is unavailable
|
||||
- Circuit breaker pattern for resilience against Redis failures
|
||||
- Health checking with automatic recovery
|
||||
- Multiple cache modes: redis-only, hybrid (L1 memory + L2 Redis), memory-only
|
||||
- Configurable timeouts and connection pooling
|
||||
- TLS support for secure Redis connections
|
||||
|
||||
The middleware gracefully handles Redis failures by falling back to in-memory
|
||||
caching, ensuring your authentication flow continues even during Redis outages.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableCircuitBreaker: true
|
||||
```
|
||||
required: false
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable Redis caching for distributed session and token management.
|
||||
When enabled, the middleware will attempt to connect to Redis and use it
|
||||
for shared state across multiple Traefik instances.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
address:
|
||||
type: string
|
||||
description: |
|
||||
Redis server address in host:port format.
|
||||
|
||||
Examples:
|
||||
- "redis:6379" (Docker/Kubernetes service)
|
||||
- "localhost:6379" (local Redis)
|
||||
- "redis.example.com:6380" (custom host/port)
|
||||
- "redis-cluster.default.svc.cluster.local:6379" (Kubernetes)
|
||||
|
||||
Required when Redis is enabled.
|
||||
required: false
|
||||
|
||||
password:
|
||||
type: string
|
||||
description: |
|
||||
Password for Redis authentication.
|
||||
Leave empty if Redis doesn't require authentication.
|
||||
|
||||
For Kubernetes deployments, you can use secret references:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
|
||||
Default: "" (no authentication)
|
||||
required: false
|
||||
|
||||
db:
|
||||
type: integer
|
||||
description: |
|
||||
Redis database number to use (0-15).
|
||||
Different databases can be used to isolate data between environments.
|
||||
|
||||
Default: 0
|
||||
required: false
|
||||
|
||||
keyPrefix:
|
||||
type: string
|
||||
description: |
|
||||
Prefix for all Redis keys created by this middleware.
|
||||
Useful for:
|
||||
- Avoiding key collisions with other applications
|
||||
- Identifying keys for monitoring/debugging
|
||||
- Supporting multiple environments in the same Redis instance
|
||||
|
||||
Default: "traefikoidc:"
|
||||
required: false
|
||||
|
||||
cacheMode:
|
||||
type: string
|
||||
description: |
|
||||
Determines the caching strategy:
|
||||
|
||||
- "redis": Redis-only caching. All cache operations go directly to Redis.
|
||||
Best for: Consistent state across all replicas, minimal memory usage.
|
||||
|
||||
- "hybrid": Two-tier caching with in-memory L1 and Redis L2.
|
||||
Best for: High performance with shared state, reduced Redis load.
|
||||
L1 provides fast local cache, L2 provides shared state.
|
||||
|
||||
- "memory": Memory-only caching (Redis disabled even if configured).
|
||||
Best for: Single instance deployments, development/testing.
|
||||
|
||||
Default: "redis" (when Redis is enabled)
|
||||
required: false
|
||||
enum:
|
||||
- redis
|
||||
- hybrid
|
||||
- memory
|
||||
|
||||
poolSize:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of socket connections to Redis.
|
||||
Higher values allow more concurrent operations but consume more resources.
|
||||
|
||||
Recommendations:
|
||||
- Small deployments: 10-20
|
||||
- Medium deployments: 20-50
|
||||
- Large deployments: 50-100
|
||||
|
||||
Default: 10
|
||||
required: false
|
||||
|
||||
connectTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for establishing new connections to Redis.
|
||||
Should be higher than network latency but low enough to fail fast.
|
||||
|
||||
Default: 5 seconds
|
||||
required: false
|
||||
|
||||
readTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis read operations.
|
||||
Includes the time to send the command, wait for Redis to process it,
|
||||
and receive the response.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
writeTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis write operations.
|
||||
Should account for network latency and Redis persistence settings.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
enableTLS:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable TLS encryption for Redis connections.
|
||||
Required when connecting to Redis instances that enforce TLS,
|
||||
such as AWS ElastiCache with encryption in transit.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
tlsSkipVerify:
|
||||
type: boolean
|
||||
description: |
|
||||
Skip TLS certificate verification for Redis connections.
|
||||
|
||||
⚠️ WARNING: Only use in development environments.
|
||||
This option bypasses certificate validation and should never be used
|
||||
in production as it's vulnerable to man-in-the-middle attacks.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
hybridL1Size:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of items in the L1 (in-memory) cache for hybrid mode.
|
||||
Controls how many cache entries are kept in local memory before eviction.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 500
|
||||
required: false
|
||||
|
||||
hybridL1MemoryMB:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum memory in megabytes for L1 cache in hybrid mode.
|
||||
The cache will start evicting items when this limit is approached.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 10 MB
|
||||
required: false
|
||||
|
||||
enableCircuitBreaker:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable circuit breaker pattern for Redis connection failures.
|
||||
|
||||
When enabled, the middleware will:
|
||||
1. Track Redis operation failures
|
||||
2. Open the circuit after threshold failures (stop trying Redis)
|
||||
3. Fall back to in-memory caching
|
||||
4. Periodically attempt to reconnect (half-open state)
|
||||
5. Resume Redis operations when connection recovers
|
||||
|
||||
This prevents cascading failures and improves resilience.
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
circuitBreakerThreshold:
|
||||
type: integer
|
||||
description: |
|
||||
Number of consecutive Redis failures before opening the circuit.
|
||||
Lower values make the system more sensitive to Redis issues,
|
||||
higher values tolerate more failures before switching to fallback.
|
||||
|
||||
Default: 5
|
||||
required: false
|
||||
|
||||
circuitBreakerTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Time in seconds to wait before attempting to close the circuit.
|
||||
After this timeout, the circuit breaker will allow one test request
|
||||
to Redis. If successful, normal operations resume.
|
||||
|
||||
Default: 60 seconds
|
||||
required: false
|
||||
|
||||
enableHealthCheck:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable periodic health checks for Redis connection.
|
||||
|
||||
Health checks:
|
||||
- Run in the background at regular intervals
|
||||
- Detect Redis availability without affecting request processing
|
||||
- Automatically reconnect when Redis becomes available
|
||||
- Update circuit breaker state based on health status
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
healthCheckInterval:
|
||||
type: integer
|
||||
description: |
|
||||
Interval in seconds between Redis health checks.
|
||||
Lower values detect issues faster but increase Redis load.
|
||||
Higher values reduce overhead but delay failure detection.
|
||||
|
||||
Default: 30 seconds
|
||||
required: false
|
||||
|
||||
@@ -8,6 +8,8 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
|
||||
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
|
||||
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
|
||||
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
|
||||
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
|
||||
- **Domain restrictions**: Limit access to specific email domains or individual users
|
||||
- **Role-based access control**: Restrict access based on roles and groups from OIDC claims
|
||||
@@ -75,11 +77,24 @@ experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.2.1 # Use the latest version
|
||||
version: v0.7.10 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Verifying Release Signatures
|
||||
|
||||
All release checksums are signed with [cosign](https://github.com/sigstore/cosign) using keyless signing. To verify:
|
||||
|
||||
```bash
|
||||
# Download the checksum file and its sigstore bundle from the release
|
||||
cosign verify-blob \
|
||||
--certificate-identity-regexp "https://github.com/lukaszraczylo/traefikoidc/.*" \
|
||||
--certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
|
||||
--bundle "traefikoidc_v<version>_checksums.txt.sigstore.json" \
|
||||
traefikoidc_v<version>_checksums.txt
|
||||
```
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
@@ -114,19 +129,46 @@ The middleware supports the following configuration options:
|
||||
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
|
||||
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `forceHTTPS` | Forces HTTPS scheme for redirect URIs (**REQUIRED** for TLS termination at load balancer like AWS ALB) | `false` (when not specified) | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `cookieDomain` | Explicit domain for session cookies (important for multi-subdomain setups) | auto-detected | `.example.com`, `app.example.com` |
|
||||
| `cookiePrefix` | Custom prefix for session cookie names (for isolating multiple middleware instances) | `_oidc_raczylo_` | `_oidc_userauth_`, `_oidc_admin_` |
|
||||
| `sessionMaxAge` | Maximum session age in seconds before requiring re-authentication | `86400` (24 hours) | `3600` (1 hour), `604800` (7 days) |
|
||||
| `audience` | Custom audience for access token validation (for Auth0 custom APIs, etc.) | `clientID` | `https://my-api.example.com` |
|
||||
| `strictAudienceValidation` | Reject sessions with access token audience mismatch (prevents token confusion attacks) | `false` | `true` |
|
||||
| `allowOpaqueTokens` | Enable opaque (non-JWT) access token support via RFC 7662 introspection | `false` | `true` |
|
||||
| `requireTokenIntrospection` | Require introspection for opaque tokens (force validation, no fallback) | `false` | `true` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
>
|
||||
> If you're running Traefik behind a load balancer (AWS ALB, Google Cloud Load Balancer, Azure Application Gateway, etc.) that terminates TLS:
|
||||
> - **You MUST set `forceHTTPS: true`** in your configuration
|
||||
> - Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures
|
||||
> - This is especially critical for AWS ALB which may overwrite the `X-Forwarded-Proto` header
|
||||
>
|
||||
> **Default behavior:**
|
||||
> - When `forceHTTPS` is **not specified** in your config → defaults to `false` (Go zero value)
|
||||
> - When `forceHTTPS: true` is explicitly set → always uses `https://` for redirect URIs
|
||||
> - When `forceHTTPS: false` is explicitly set → scheme detection based on headers/TLS
|
||||
>
|
||||
> See [GitHub Issue #82](https://github.com/lukaszraczylo/traefikoidc/issues/82) for details.
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
@@ -201,6 +243,103 @@ scopes: []
|
||||
|
||||
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
|
||||
|
||||
## Auth0 Audience Validation & Security
|
||||
|
||||
The middleware provides comprehensive support for Auth0 audience validation to prevent token confusion attacks. Auth0 can issue tokens in three different scenarios, each requiring specific configuration.
|
||||
|
||||
### Understanding Token Audiences
|
||||
|
||||
Per OAuth 2.0 and OIDC specifications:
|
||||
- **ID Tokens**: MUST have `aud = client_id` (OIDC Core 1.0 spec)
|
||||
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
|
||||
### Auth0 Scenarios
|
||||
|
||||
#### Scenario 1: Custom API Audience ✅ (RECOMMENDED)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
strictAudienceValidation: true # Enforce strict validation
|
||||
```
|
||||
|
||||
**Result**: Fully secure, OIDC compliant with proper access token audience validation.
|
||||
|
||||
#### Scenario 2: Default Audience ⚠️ (USE WITH CAUTION)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# audience not specified (defaults to client_id)
|
||||
strictAudienceValidation: true # Recommended: reject mismatched tokens
|
||||
```
|
||||
|
||||
**Behavior**: Access tokens may not contain client_id in audience, triggering security warnings. Set `strictAudienceValidation: true` to reject such sessions.
|
||||
|
||||
#### Scenario 3: Opaque Access Tokens ✅ (SUPPORTED)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true # Enable opaque token support
|
||||
requireTokenIntrospection: true # Require introspection (recommended)
|
||||
```
|
||||
|
||||
**Result**: Secure with OAuth 2.0 Token Introspection (RFC 7662).
|
||||
|
||||
### Security Configuration Options
|
||||
|
||||
| Option | Purpose | Recommended Value |
|
||||
|--------|---------|-------------------|
|
||||
| `audience` | Expected audience for access tokens | Your API identifier or leave empty |
|
||||
| `strictAudienceValidation` | Reject sessions with audience mismatch | `true` for production |
|
||||
| `allowOpaqueTokens` | Accept non-JWT access tokens | `true` if provider issues opaque tokens |
|
||||
| `requireTokenIntrospection` | Force introspection for opaque tokens | `true` when `allowOpaqueTokens=true` |
|
||||
|
||||
### Complete Auth0 Configuration Examples
|
||||
|
||||
**Production Configuration (Scenario 1):**
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-secure
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-auth0-domain.auth0.com
|
||||
clientID: your-auth0-client-id
|
||||
clientSecret: your-auth0-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
allowedRolesAndGroups:
|
||||
- "https://your-app.com/roles:admin"
|
||||
- editor
|
||||
```
|
||||
|
||||
**Opaque Token Configuration (Scenario 3):**
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-opaque
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-auth0-domain.auth0.com
|
||||
clientID: your-auth0-client-id
|
||||
clientSecret: your-auth0-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
For detailed Auth0 configuration including all three scenarios, troubleshooting, and security best practices, see **[AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md)**.
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The middleware includes comprehensive security headers support to protect your applications against common web vulnerabilities. Security headers are applied to all authenticated responses.
|
||||
@@ -319,6 +458,10 @@ securityHeaders:
|
||||
| `customHeaders` | Additional custom headers | `{}` | `{"X-Custom": "value"}` |
|
||||
| `disableServerHeader` | Remove Server header | `true` | `true`, `false` |
|
||||
| `disablePoweredByHeader` | Remove X-Powered-By header | `true` | `true`, `false` |
|
||||
| `permissionsPolicy` | Permissions-Policy header | `` | `"geolocation=(), camera=(), microphone=()"` |
|
||||
| `crossOriginEmbedderPolicy` | Cross-Origin-Embedder-Policy header | `` | `"require-corp"`, `"credentialless"`, `"unsafe-none"` |
|
||||
| `crossOriginOpenerPolicy` | Cross-Origin-Opener-Policy header | `` | `"same-origin"`, `"same-origin-allow-popups"`, `"unsafe-none"` |
|
||||
| `crossOriginResourcePolicy` | Cross-Origin-Resource-Policy header | `` | `"same-origin"`, `"same-site"`, `"cross-origin"` |
|
||||
|
||||
### CORS Wildcard Support
|
||||
|
||||
@@ -390,6 +533,316 @@ securityHeaders:
|
||||
corsAllowedOrigins: ["http://localhost:*"]
|
||||
```
|
||||
|
||||
### Multi-Replica Deployment Configuration
|
||||
|
||||
When running multiple Traefik replicas with the OIDC plugin, you may encounter false positive replay detection errors. Each replica maintains its own in-memory JTI (JWT Token ID) cache, causing legitimate token reuse to be flagged as replay attacks.
|
||||
|
||||
**Problem**: When the same valid token hits different replicas:
|
||||
- Request → Replica A → JTI added to Replica A's cache ✓
|
||||
- Request → Replica B → JTI NOT in Replica B's cache ✓
|
||||
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
|
||||
|
||||
**Solution 1 (Simple)**: Disable replay detection for distributed deployments:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
|
||||
```
|
||||
|
||||
**Solution 2 (Recommended)**: Use Redis cache backend for shared state (see [Redis Cache](#redis-cache-optional) section)
|
||||
|
||||
**Security Note**: When `disableReplayDetection: true`:
|
||||
- ✅ Token signatures still validated
|
||||
- ✅ Expiration still checked
|
||||
- ✅ All other claims still verified
|
||||
- ❌ JTI replay check **skipped**
|
||||
|
||||
**Example Configuration**:
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-multi-replica
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
disableReplayDetection: true # Required for multi-replica deployments without Redis
|
||||
```
|
||||
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, use the Redis cache backend for proper replay detection across all instances.
|
||||
|
||||
## Redis Cache (Optional)
|
||||
|
||||
The plugin supports optional Redis caching for multi-replica deployments. This solves issues with JTI replay detection and session management when running multiple Traefik instances behind a load balancer.
|
||||
|
||||
> **✨ Yaegi Compatible**: Redis support is implemented using a pure-Go RESP protocol client that works seamlessly with Traefik's Yaegi interpreter (no `unsafe` package). Full Redis functionality is available for both dynamic plugin loading and pre-compiled deployments.
|
||||
|
||||
### Why Use Redis Cache?
|
||||
|
||||
When running multiple Traefik replicas, each instance maintains its own in-memory cache for:
|
||||
- JTI (JWT Token ID) replay detection
|
||||
- Session data
|
||||
- Token metadata
|
||||
|
||||
Without a shared cache, you may experience:
|
||||
- False positive replay detection errors
|
||||
- Session inconsistencies between replicas
|
||||
- Users needing to re-authenticate when hitting different instances
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Redis is configured through Traefik's dynamic configuration (YAML, labels, etc.):
|
||||
|
||||
```yaml
|
||||
# Enable Redis cache in your middleware configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "localhost:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
```
|
||||
|
||||
### Configuration Priority
|
||||
|
||||
The plugin uses the following priority for Redis configuration:
|
||||
|
||||
1. **Traefik Dynamic Configuration** (PRIMARY) - Configure via YAML files or Docker/Kubernetes labels
|
||||
2. **Environment Variables** (FALLBACK) - Used only when not set in Traefik config
|
||||
|
||||
This approach allows you to manage all settings through Traefik's configuration system while maintaining backward compatibility with environment variables.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `enabled` | Enable Redis caching | `false` | `true` |
|
||||
| `address` | Redis server address | - | `redis:6379` |
|
||||
| `password` | Redis password | - | `YOUR_PASSWORD` |
|
||||
| `db` | Database number | `0` | `1` |
|
||||
| `keyPrefix` | Key prefix for namespacing | `traefikoidc:` | `myapp:` |
|
||||
| `cacheMode` | Cache mode: `redis`, `hybrid`, `memory` | `redis` | `hybrid` |
|
||||
| `poolSize` | Connection pool size | `10` | `20` |
|
||||
| `connectTimeout` | Connection timeout (seconds) | `5` | `10` |
|
||||
| `readTimeout` | Read timeout (seconds) | `3` | `5` |
|
||||
| `writeTimeout` | Write timeout (seconds) | `3` | `5` |
|
||||
| `enableTLS` | Enable TLS | `false` | `true` |
|
||||
| `tlsSkipVerify` | Skip TLS verification | `false` | `true` |
|
||||
| `enableCircuitBreaker` | Circuit breaker for failures | `true` | `true` |
|
||||
| `circuitBreakerThreshold` | Failures before circuit opens | `5` | `10` |
|
||||
| `circuitBreakerTimeout` | Circuit reset timeout (seconds) | `60` | `30` |
|
||||
| `enableHealthCheck` | Periodic health checks | `true` | `true` |
|
||||
| `healthCheckInterval` | Health check interval (seconds) | `30` | `60` |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables can be used as fallback:
|
||||
|
||||
- `REDIS_ENABLED` - Enable Redis cache
|
||||
- `REDIS_ADDRESS` - Redis server address
|
||||
- `REDIS_PASSWORD` - Redis password
|
||||
- `REDIS_DB` - Database number
|
||||
- `REDIS_KEY_PREFIX` - Key prefix
|
||||
- `REDIS_CACHE_MODE` - Cache mode
|
||||
- `REDIS_POOL_SIZE` - Connection pool size
|
||||
- `REDIS_CONNECT_TIMEOUT` - Connection timeout
|
||||
- `REDIS_READ_TIMEOUT` - Read timeout
|
||||
- `REDIS_WRITE_TIMEOUT` - Write timeout
|
||||
- `REDIS_ENABLE_TLS` - Enable TLS
|
||||
- `REDIS_TLS_SKIP_VERIFY` - Skip TLS verification
|
||||
|
||||
### Cache Modes
|
||||
|
||||
The plugin supports three cache modes:
|
||||
|
||||
- **memory** (default): In-memory cache only, suitable for single-instance deployments
|
||||
- **redis**: Redis-only cache, all data stored in Redis
|
||||
- **hybrid**: Two-tier caching with local memory cache + Redis backend for optimal performance
|
||||
|
||||
### Example Configurations
|
||||
|
||||
#### Docker Compose with Redis
|
||||
|
||||
```yaml
|
||||
services:
|
||||
redis:
|
||||
image: redis:alpine
|
||||
command: redis-server --requirepass yourpassword
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
# ... rest of your Traefik configuration
|
||||
labels:
|
||||
# Configure the OIDC middleware with Redis
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# Redis configuration via labels
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=yourpassword"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
```
|
||||
|
||||
#### Kubernetes with Redis
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### Advanced Redis Configuration
|
||||
|
||||
See [Redis Cache Documentation](docs/REDIS_CACHE.md) for:
|
||||
- Detailed architecture overview
|
||||
- High availability setup with Redis Sentinel
|
||||
- Redis Cluster configuration
|
||||
- Performance tuning guidelines
|
||||
- Monitoring and observability
|
||||
- Troubleshooting guide
|
||||
- Migration from memory-only cache
|
||||
|
||||
## Dynamic Client Registration (RFC 7591)
|
||||
|
||||
The middleware supports **OIDC Dynamic Client Registration** (RFC 7591), allowing automatic client registration with OIDC providers without manual pre-registration. This is useful for:
|
||||
|
||||
- **Multi-tenant deployments**: Automatically register clients per tenant
|
||||
- **Development environments**: Quick setup without manual OAuth app creation
|
||||
- **Self-service integrations**: Allow applications to self-register
|
||||
|
||||
### How It Works
|
||||
|
||||
1. When enabled, the middleware discovers the `registration_endpoint` from the provider's `.well-known/openid-configuration`
|
||||
2. If no `clientID` is configured, it automatically registers a new client with the provider
|
||||
3. The registered `client_id` and `client_secret` are cached and optionally persisted to a file
|
||||
4. Subsequent requests use the registered credentials
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-dynamic-registration
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-oidc-provider.com
|
||||
# clientID and clientSecret are NOT required when using DCR
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
|
||||
# Optional: Initial access token for protected registration endpoints
|
||||
initialAccessToken: "your-initial-access-token"
|
||||
|
||||
# Optional: Override the registration endpoint (auto-discovered by default)
|
||||
registrationEndpoint: "https://your-provider.com/register"
|
||||
|
||||
# Optional: Persist credentials to file for reuse across restarts
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-client-credentials.json"
|
||||
|
||||
# Client metadata for registration
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
response_types:
|
||||
- "code"
|
||||
token_endpoint_auth_method: "client_secret_basic"
|
||||
contacts:
|
||||
- "admin@your-app.com"
|
||||
```
|
||||
|
||||
### DCR Configuration Parameters
|
||||
|
||||
| Parameter | Description | Required | Default |
|
||||
|-----------|-------------|----------|---------|
|
||||
| `enabled` | Enable dynamic client registration | Yes | `false` |
|
||||
| `initialAccessToken` | Bearer token for protected registration endpoints | No | - |
|
||||
| `registrationEndpoint` | Override auto-discovered registration endpoint | No | From discovery |
|
||||
| `persistCredentials` | Save registered credentials to file | No | `false` |
|
||||
| `credentialsFile` | Path to store/load credentials | No | `/tmp/oidc-client-credentials.json` |
|
||||
| `clientMetadata.redirect_uris` | **REQUIRED** - Redirect URIs for OAuth flow | Yes | - |
|
||||
| `clientMetadata.client_name` | Human-readable client name | No | - |
|
||||
| `clientMetadata.application_type` | `web` or `native` | No | `web` |
|
||||
| `clientMetadata.grant_types` | OAuth grant types | No | `["authorization_code", "refresh_token"]` |
|
||||
| `clientMetadata.response_types` | OAuth response types | No | `["code"]` |
|
||||
| `clientMetadata.token_endpoint_auth_method` | Authentication method | No | `client_secret_basic` |
|
||||
| `clientMetadata.contacts` | Contact email addresses | No | - |
|
||||
| `clientMetadata.logo_uri` | URL to client logo | No | - |
|
||||
| `clientMetadata.client_uri` | URL to client homepage | No | - |
|
||||
| `clientMetadata.policy_uri` | URL to privacy policy | No | - |
|
||||
| `clientMetadata.tos_uri` | URL to terms of service | No | - |
|
||||
| `clientMetadata.scope` | Space-separated scopes | No | - |
|
||||
|
||||
### Provider Support
|
||||
|
||||
DCR support varies by provider:
|
||||
|
||||
| Provider | DCR Support | Notes |
|
||||
|----------|-------------|-------|
|
||||
| Keycloak | ✅ Full | Enable in realm settings |
|
||||
| Auth0 | ✅ Full | Requires Management API token |
|
||||
| Okta | ✅ Full | Enable Dynamic Client Registration |
|
||||
| Azure AD | ⚠️ Limited | App Registration API instead |
|
||||
| Google | ❌ No | Manual registration required |
|
||||
| AWS Cognito | ❌ No | Manual registration required |
|
||||
|
||||
### Security Considerations
|
||||
|
||||
1. **HTTPS Required**: Registration endpoints must use HTTPS (except localhost for development)
|
||||
2. **Initial Access Token**: Recommended for production to prevent unauthorized registrations
|
||||
3. **Credential Persistence**: If enabled, ensure the credentials file has appropriate permissions (0600)
|
||||
4. **Secret Expiration**: Monitor `client_secret_expires_at` and handle rotation if needed
|
||||
|
||||
### Example: Keycloak with DCR
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://myapp.example.com/oauth2/callback"
|
||||
client_name: "My App - Production"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
@@ -568,6 +1021,87 @@ spec:
|
||||
|
||||
**Important**: The `cookieDomain` parameter is crucial when running behind a reverse proxy or when your application serves multiple subdomains. Without it, cookies may be created with inconsistent domains, leading to authentication issues like "CSRF token missing in session" errors.
|
||||
|
||||
### With Multiple Middleware Instances (Session Isolation)
|
||||
|
||||
When running multiple middleware instances with different authorization requirements (e.g., one for general users and one for admins), you must use different `cookiePrefix` values to prevent session sharing between instances:
|
||||
|
||||
```yaml
|
||||
# Middleware for general user authentication
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-userauth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: user-key-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
cookiePrefix: "_oidc_userauth_" # Unique prefix for this instance
|
||||
---
|
||||
# Middleware for admin authentication with stricter requirements
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-adminauth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: admin-key-at-least-32-bytes-long # Different encryption key
|
||||
callbackURL: /oauth2/admin/callback # Different callback URL
|
||||
cookiePrefix: "_oidc_adminauth_" # Different prefix for isolation
|
||||
allowedUsers: # Restricted to specific admin users
|
||||
- admin@example.com
|
||||
- superadmin@example.com
|
||||
```
|
||||
|
||||
**Security Note**: When running multiple instances, ensure you use:
|
||||
1. **Different `cookiePrefix`** values to prevent cookie name collisions
|
||||
2. **Different `sessionEncryptionKey`** values for complete session isolation
|
||||
3. **Different `callbackURL`** paths to avoid routing conflicts
|
||||
|
||||
This configuration prevents authorization bypass issues where a user authenticated via the general middleware could access admin-protected routes. See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87) for more details.
|
||||
|
||||
### With Extended Session Duration
|
||||
|
||||
For applications that users access infrequently (weekly or monthly), you can extend the session duration beyond the default 24 hours to reduce authentication friction:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-long-session
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-key-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
sessionMaxAge: 604800 # 7 days (in seconds)
|
||||
# Other common values:
|
||||
# 259200 - 3 days
|
||||
# 604800 - 7 days
|
||||
# 1209600 - 14 days
|
||||
# 2592000 - 30 days
|
||||
```
|
||||
|
||||
**Security Note**: Longer session durations improve user experience but increase security risk. Consider your application's security requirements:
|
||||
- **High-security apps**: Use shorter sessions (3600 = 1 hour)
|
||||
- **Standard apps**: Default 24 hours balances security and UX
|
||||
- **Low-frequency access apps**: Extend to 7-30 days for better UX
|
||||
|
||||
See [issue #91](https://github.com/lukaszraczylo/traefikoidc/issues/91) for more details.
|
||||
|
||||
### With Custom Logging and Rate Limiting
|
||||
|
||||
```yaml
|
||||
@@ -723,6 +1257,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -740,14 +1313,26 @@ spec:
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Audience configuration for custom APIs
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
strictAudienceValidation: true # Enforce proper audience validation
|
||||
|
||||
scopes:
|
||||
- read:custom_data # Custom scopes as needed
|
||||
|
||||
# Custom claim names for Auth0 namespaced claims
|
||||
roleClaimName: "https://your-app.com/roles" # Auth0 requires namespaced custom claims
|
||||
groupClaimName: "https://your-app.com/groups" # Must match claims added in Auth0 Actions
|
||||
|
||||
allowedRolesAndGroups:
|
||||
- "https://your-app.com/roles:admin" # Namespaced claims from Actions
|
||||
- admin # Will match "admin" in https://your-app.com/roles claim
|
||||
- editor
|
||||
postLogoutRedirectURI: /logged-out-page # Must be in Auth0 Allowed Logout URLs
|
||||
```
|
||||
|
||||
**Note**: For detailed Auth0 audience configuration including opaque tokens and all security scenarios, see [AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md).
|
||||
|
||||
### Okta Configuration
|
||||
|
||||
```yaml
|
||||
@@ -797,8 +1382,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -920,7 +1509,7 @@ services:
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.2.1"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.7.10"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./traefik-config/traefik.yml:/etc/traefik/traefik.yml
|
||||
@@ -1027,58 +1616,6 @@ http:
|
||||
{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Session Management
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### PKCE Support
|
||||
|
||||
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||
|
||||
PKCE is recommended when:
|
||||
- Your OIDC provider supports it (most modern providers do)
|
||||
- You need an additional layer of security for the authorization code flow
|
||||
- You're concerned about potential authorization code interception attacks
|
||||
|
||||
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||
|
||||
### Session Duration and Token Refresh
|
||||
|
||||
This middleware aims to provide long-lived user sessions, typically up to 24 hours, by utilizing OIDC refresh tokens.
|
||||
|
||||
**How it works:**
|
||||
- When a user authenticates, the middleware requests an access token and, if available, a refresh token from the OIDC provider.
|
||||
- The access token usually has a short lifespan (e.g., 1 hour).
|
||||
- Before the access token expires (controlled by `refreshGracePeriodSeconds`), the middleware uses the refresh token to obtain a new access token from the provider without requiring the user to log in again.
|
||||
- This process repeats, allowing the session to remain valid for as long as the refresh token is valid (often 24 hours or more, depending on the provider).
|
||||
|
||||
**Provider-Specific Considerations (e.g., Google):**
|
||||
- Some providers, like Google, issue short-lived access tokens (e.g., 1 hour) and require specific configurations for long-term sessions.
|
||||
- To enable session extension beyond the initial token expiry with Google and similar providers, the middleware automatically includes the `offline_access` scope in the authentication request. This scope is necessary to obtain a refresh token.
|
||||
- For Google specifically, the middleware also adds the `prompt=consent` parameter to the initial authorization request. This ensures Google issues a refresh token, which is crucial for extending the session.
|
||||
- If a refresh attempt fails (e.g., the refresh token is revoked or expired), the user will be required to re-authenticate. The middleware includes enhanced error handling and logging for these scenarios.
|
||||
- Ensure your OIDC provider is configured to issue refresh tokens and allows their use for extending sessions. Check your provider's documentation for details on refresh token validity periods.
|
||||
|
||||
### Google OAuth Compatibility Fix
|
||||
|
||||
The middleware includes a specific fix for Google's OAuth implementation, which differs from the standard OIDC specification in how it handles refresh tokens:
|
||||
|
||||
- **Issue**: Google does not support the standard `offline_access` scope for requesting refresh tokens and instead requires special parameters.
|
||||
|
||||
- **Automatic Solution**: The middleware detects Google as the provider based on the issuer URL and:
|
||||
- Uses `access_type=offline` query parameter instead of the `offline_access` scope
|
||||
- Adds `prompt=consent` to ensure refresh tokens are consistently issued
|
||||
- Properly handles token refresh with Google's implementation
|
||||
|
||||
You do not need any special configuration to use Google OAuth - just set `providerURL` to `https://accounts.google.com` and the middleware will automatically apply the proper parameters.
|
||||
|
||||
For detailed information on the Google OAuth fix, see the [dedicated documentation](docs/google-oauth-fix.md).
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
### Templated Headers
|
||||
|
||||
The middleware supports setting custom HTTP headers with values templated from OIDC claims and tokens. This allows you to pass authentication information to downstream services in a flexible, customized format.
|
||||
@@ -1151,12 +1688,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1280,32 +1844,6 @@ GitLab supports OIDC for both GitLab.com and self-hosted instances.
|
||||
* **Scopes**: Use `user:email`, `read:user` for basic profile access
|
||||
* **Detection**: Auto-detected from `github.com` in issuer URL
|
||||
|
||||
### Azure AD (Microsoft Entra ID)
|
||||
|
||||
Azure AD generally works well with standard OIDC configurations.
|
||||
|
||||
* **ID Token Claims**: Azure AD typically includes standard claims like `email`, `name`, `preferred_username`, and `oid` (Object ID) in the ID Token by default when `openid profile email` scopes are requested.
|
||||
* **Group Claims**: To include group claims in the ID Token, you need to configure this in the Azure AD application registration:
|
||||
* Go to your App Registration -> Token configuration -> Add groups claim.
|
||||
* You can choose which types of groups (Security groups, Directory roles, All groups) to include.
|
||||
* Be aware of the "overage" issue: If a user is a member of too many groups, Azure AD will send a link to fetch groups instead of embedding them. This plugin currently expects group claims to be directly in the ID token. For users with many groups, consider alternative role/permission management strategies.
|
||||
* The claim name for groups is typically `groups`.
|
||||
* **Optional Claims**: You can add other optional claims via the "Token configuration" section of your App Registration. Ensure these are configured for the ID token.
|
||||
* **Endpoints**: The `providerURL` should be `https://login.microsoftonline.com/{your-tenant-id}/v2.0`. The plugin will auto-discover the necessary endpoints.
|
||||
* **Optimization**: Ensure your application manifest in Azure AD is configured for the desired token version (v1.0 or v2.0). This plugin works with v2.0 endpoints.
|
||||
|
||||
### Google Workspace / Google Cloud Identity
|
||||
|
||||
Google's OIDC implementation is well-supported.
|
||||
|
||||
* **Optimal Configuration**: The plugin automatically handles Google-specific requirements, such as using `access_type=offline` and `prompt=consent` to ensure refresh tokens are issued for long-lived sessions. You do not need to add `offline_access` to scopes.
|
||||
* **ID Token Claims**: Google includes standard claims like `email`, `sub`, `name`, `given_name`, `family_name`, `picture` in the ID Token by default with `openid profile email` scopes.
|
||||
* **Hosted Domain (hd claim)**: If you are using Google Workspace and want to restrict access to users within your organization's domain, Google includes an `hd` (hosted domain) claim in the ID Token. You can use this with the `allowedUserDomains` setting or for custom header logic.
|
||||
* **Best Practices**:
|
||||
* Use the `providerURL`: `https://accounts.google.com`.
|
||||
* Ensure your OAuth consent screen in Google Cloud Console is configured correctly and published. For production, it should be "External" and in "Production" status. "Testing" status limits refresh token lifetime.
|
||||
* Refer to the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section for more details on how the plugin handles Google's specifics.
|
||||
|
||||
### Auth0
|
||||
|
||||
Auth0 is generally OIDC compliant and works well.
|
||||
@@ -1410,6 +1948,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
# Test Execution Guide
|
||||
|
||||
This guide explains how to run tests efficiently with the new test categorization and optimization system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Fast Development Testing (Default - Target: < 30 seconds)
|
||||
```bash
|
||||
# Run quick smoke tests only
|
||||
go test ./...
|
||||
|
||||
# Or explicitly run in short mode
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Extended Testing (Target: 2-5 minutes)
|
||||
```bash
|
||||
# Enable extended tests with more iterations and concurrency
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Or use the flag equivalent (if using test runner that supports it)
|
||||
go test ./... -extended
|
||||
```
|
||||
|
||||
### Long-Running Performance Tests (Target: 5-15 minutes)
|
||||
```bash
|
||||
# Enable comprehensive performance and stress tests
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Full Stress Testing (Target: 10-30 minutes)
|
||||
```bash
|
||||
# Enable all stress tests with maximum parameters
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Quick Tests (Default)
|
||||
- **Purpose**: Fast feedback during development
|
||||
- **Duration**: < 30 seconds total
|
||||
- **Features**:
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Minimal concurrency
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Cache Size: 50
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### 2. Extended Tests
|
||||
- **Purpose**: Comprehensive testing before commits
|
||||
- **Duration**: 2-5 minutes
|
||||
- **Features**:
|
||||
- Increased test coverage
|
||||
- More iterations (5-10)
|
||||
- Medium concurrency tests
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Cache Size: 200
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### 3. Long Tests
|
||||
- **Purpose**: Performance validation and stress testing
|
||||
- **Duration**: 5-15 minutes
|
||||
- **Features**:
|
||||
- High iteration counts (50-100)
|
||||
- High concurrency scenarios
|
||||
- Large data sets
|
||||
- Comprehensive memory testing
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Cache Size: 1000
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### 4. Stress Tests
|
||||
- **Purpose**: Maximum load testing and edge case validation
|
||||
- **Duration**: 10-30 minutes
|
||||
- **Features**:
|
||||
- Extreme iteration counts (100-500)
|
||||
- Maximum concurrency (100+)
|
||||
- Large memory allocations
|
||||
- Edge case combinations
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Cache Size: 2000
|
||||
- Timeout: 120 seconds
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Test Execution Control
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1 # Enable extended tests
|
||||
export RUN_LONG_TESTS=1 # Enable long-running tests
|
||||
export RUN_STRESS_TESTS=1 # Enable stress tests
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
|
||||
```
|
||||
|
||||
### Parameter Customization
|
||||
```bash
|
||||
# Customize concurrency limits
|
||||
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
|
||||
|
||||
# Customize iteration limits
|
||||
export TEST_MAX_ITERATIONS=50 # Override max test iterations
|
||||
|
||||
# Customize memory thresholds
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
|
||||
```
|
||||
|
||||
## Test-Specific Behavior
|
||||
|
||||
### Memory Leak Tests
|
||||
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
|
||||
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
|
||||
- **Long Mode**: 50-100 iterations, large data sets, performance focus
|
||||
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
|
||||
|
||||
### Concurrency Tests
|
||||
- **Quick Mode**: 2-5 concurrent operations, basic race detection
|
||||
- **Extended Mode**: 10-20 concurrent operations, moderate stress
|
||||
- **Long Mode**: 20-50 concurrent operations, high contention
|
||||
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
|
||||
|
||||
### Cache Tests
|
||||
- **Quick Mode**: Small caches (50 items), basic operations
|
||||
- **Extended Mode**: Medium caches (200 items), varied operations
|
||||
- **Long Mode**: Large caches (1000 items), performance testing
|
||||
- **Stress Mode**: Very large caches (2000+ items), stress testing
|
||||
|
||||
## Integration with CI/CD
|
||||
|
||||
### GitHub Actions Example
|
||||
```yaml
|
||||
# Quick tests for every push/PR
|
||||
- name: Quick Tests
|
||||
run: go test ./... -short
|
||||
|
||||
# Extended tests for main branch
|
||||
- name: Extended Tests
|
||||
if: github.ref == 'refs/heads/main'
|
||||
run: RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Nightly comprehensive testing
|
||||
- name: Nightly Stress Tests
|
||||
if: github.event_name == 'schedule'
|
||||
run: RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Local Development Workflow
|
||||
```bash
|
||||
# During active development
|
||||
go test ./... -short
|
||||
|
||||
# Before committing
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Before major releases
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Performance validation
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Performance Optimization Features
|
||||
|
||||
### Dynamic Test Scaling
|
||||
The test system automatically adjusts parameters based on:
|
||||
- Test mode (quick/extended/long/stress)
|
||||
- Available resources
|
||||
- Environment variables
|
||||
- Previous test performance
|
||||
|
||||
### Memory Management
|
||||
- **Garbage Collection**: Forced GC between test iterations
|
||||
- **Memory Monitoring**: Real-time memory growth tracking
|
||||
- **Leak Detection**: Goroutine and memory leak prevention
|
||||
- **Resource Cleanup**: Automatic cleanup of test resources
|
||||
|
||||
### Timeout Management
|
||||
- **Adaptive Timeouts**: Timeouts scale with test complexity
|
||||
- **Graceful Degradation**: Tests adapt to slower environments
|
||||
- **Early Termination**: Failed tests terminate quickly
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Tests Taking Too Long
|
||||
```bash
|
||||
# Check if running in extended mode accidentally
|
||||
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
|
||||
|
||||
# Force quick mode
|
||||
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
```bash
|
||||
# Reduce memory limits for constrained environments
|
||||
export TEST_MEMORY_THRESHOLD_MB=5.0
|
||||
export TEST_MAX_CONCURRENCY=2
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Concurrency Issues
|
||||
```bash
|
||||
# Reduce concurrency for slower systems
|
||||
export TEST_MAX_CONCURRENCY=5
|
||||
export TEST_MAX_ITERATIONS=10
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Skip Specific Test Types
|
||||
```bash
|
||||
# Skip memory leak detection if problematic
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
go test ./...
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Running Benchmarks
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
### Benchmark Categories
|
||||
- **Basic Operations**: Set/Get performance
|
||||
- **Concurrency**: Multi-threaded performance
|
||||
- **Memory**: Allocation and cleanup performance
|
||||
- **Cache**: Eviction and cleanup performance
|
||||
|
||||
## Best Practices
|
||||
|
||||
### For Developers
|
||||
1. Always run quick tests during development (`go test ./... -short`)
|
||||
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
|
||||
3. Use appropriate test categories for your use case
|
||||
4. Monitor test execution time and adjust if needed
|
||||
|
||||
### For CI/CD
|
||||
1. Use quick tests for fast feedback on PRs
|
||||
2. Use extended tests for main branch validation
|
||||
3. Use long tests for release validation
|
||||
4. Use stress tests for nightly/weekly validation
|
||||
|
||||
### For Performance Testing
|
||||
1. Use consistent environment variables
|
||||
2. Run tests multiple times for statistical significance
|
||||
3. Monitor both execution time and resource usage
|
||||
4. Use profiling tools for detailed analysis
|
||||
|
||||
## Examples
|
||||
|
||||
### Daily Development
|
||||
```bash
|
||||
# Fast tests while coding
|
||||
go test ./... -short
|
||||
|
||||
# Before git commit
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Release Testing
|
||||
```bash
|
||||
# Comprehensive validation
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Stress testing
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
```bash
|
||||
# Custom limits for specific environment
|
||||
export TEST_MAX_CONCURRENCY=8
|
||||
export TEST_MAX_ITERATIONS=25
|
||||
export TEST_MEMORY_THRESHOLD_MB=15.0
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
|
||||
+1518
File diff suppressed because it is too large
Load Diff
@@ -1,360 +0,0 @@
|
||||
// Package auth provides authentication-related functionality for the OIDC middleware.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AuthHandler provides core authentication functionality for OIDC flows
|
||||
type AuthHandler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
session SessionData, redirectURL string,
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
||||
|
||||
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
const maxRedirects = 5
|
||||
redirectCount := session.GetRedirectCount()
|
||||
if redirectCount >= maxRedirects {
|
||||
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
|
||||
session.ResetRedirectCount()
|
||||
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
|
||||
return
|
||||
}
|
||||
|
||||
session.IncrementRedirectCount()
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge if PKCE is enabled
|
||||
var codeVerifier, codeChallenge string
|
||||
if h.enablePKCE {
|
||||
codeVerifier, err = generateCodeVerifier()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code verifier: %v", err)
|
||||
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
codeChallenge, err = deriveCodeChallenge()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code challenge: %v", err)
|
||||
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("PKCE enabled, generated code challenge")
|
||||
}
|
||||
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
if h.enablePKCE {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
|
||||
|
||||
session.MarkDirty()
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
|
||||
csrfToken, nonce)
|
||||
|
||||
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.clientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
if h.enablePKCE && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
scopes := make([]string, len(h.scopes))
|
||||
copy(scopes, h.scopes)
|
||||
|
||||
if h.isGoogleProv() {
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else if h.isAzureProv() {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
|
||||
hasOfflineAccess := false
|
||||
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
finalScopeString := strings.Join(scopes, " ")
|
||||
params.Set("scope", finalScopeString)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
return h.buildURLWithParams(h.authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
// It handles both relative and absolute URLs, validates URL security,
|
||||
// and properly encodes query parameters.
|
||||
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if baseURL != "" {
|
||||
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
||||
if err := h.validateURL(baseURL); err != nil {
|
||||
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
issuerURLParsed, err := url.Parse(h.issuerURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
||||
|
||||
if err := h.validateURL(resolvedURL.String()); err != nil {
|
||||
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL.RawQuery = params.Encode()
|
||||
return resolvedURL.String()
|
||||
}
|
||||
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if err := h.validateParsedURL(u); err != nil {
|
||||
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
||||
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
||||
func (h *AuthHandler) validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("empty URL")
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
|
||||
return h.validateParsedURL(u)
|
||||
}
|
||||
|
||||
// validateParsedURL validates a parsed URL structure for security.
|
||||
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
||||
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
|
||||
allowedSchemes := map[string]bool{
|
||||
"https": true,
|
||||
"http": true,
|
||||
}
|
||||
|
||||
if !allowedSchemes[u.Scheme] {
|
||||
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
|
||||
}
|
||||
|
||||
if u.Scheme == "http" {
|
||||
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
|
||||
if err := h.validateHost(u.Host); err != nil {
|
||||
return fmt.Errorf("invalid host: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(u.Path, "..") {
|
||||
return fmt.Errorf("path traversal detected in URL path")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
func (h *AuthHandler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
// Strip port if present
|
||||
if strings.Contains(host, ":") {
|
||||
var err error
|
||||
host, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid host:port format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
for _, localhost := range localhostVariations {
|
||||
if strings.EqualFold(host, localhost) {
|
||||
return fmt.Errorf("localhost access not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as IP address
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
return fmt.Errorf("link-local IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsMulticast() {
|
||||
return fmt.Errorf("multicast IP not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionData interface for dependency injection
|
||||
type SessionData interface {
|
||||
GetRedirectCount() int
|
||||
ResetRedirectCount()
|
||||
IncrementRedirectCount()
|
||||
SetAuthenticated(bool)
|
||||
SetEmail(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetIDToken(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetCSRF(string)
|
||||
SetIncomingPath(string)
|
||||
MarkDirty()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
}
|
||||
@@ -1,599 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test mocks
|
||||
type mockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
redirectCount int
|
||||
saveError error
|
||||
dirty bool
|
||||
}
|
||||
|
||||
func (s *mockSessionData) GetRedirectCount() int { return s.redirectCount }
|
||||
func (s *mockSessionData) ResetRedirectCount() { s.redirectCount = 0 }
|
||||
func (s *mockSessionData) IncrementRedirectCount() { s.redirectCount++ }
|
||||
func (s *mockSessionData) SetAuthenticated(auth bool) { s.authenticated = auth }
|
||||
func (s *mockSessionData) SetEmail(email string) { s.email = email }
|
||||
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
|
||||
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
|
||||
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
|
||||
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
|
||||
func (s *mockSessionData) SetCodeVerifier(verifier string) { s.codeVerifier = verifier }
|
||||
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
|
||||
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
|
||||
func (s *mockSessionData) MarkDirty() { s.dirty = true }
|
||||
|
||||
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return s.saveError
|
||||
}
|
||||
|
||||
// TestAuthHandler_NewAuthHandler tests the constructor
|
||||
func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
isGoogleProv := func() bool { return false }
|
||||
isAzureProv := func() bool { return true }
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
|
||||
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
||||
"test-client-id", "https://example.com/auth", "https://example.com",
|
||||
scopes, false)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if !handler.enablePKCE {
|
||||
t.Error("PKCE should be enabled")
|
||||
}
|
||||
|
||||
if handler.clientID != "test-client-id" {
|
||||
t.Errorf("Expected clientID 'test-client-id', got '%s'", handler.clientID)
|
||||
}
|
||||
|
||||
if handler.authURL != "https://example.com/auth" {
|
||||
t.Errorf("Expected authURL 'https://example.com/auth', got '%s'", handler.authURL)
|
||||
}
|
||||
|
||||
if handler.issuerURL != "https://example.com" {
|
||||
t.Errorf("Expected issuerURL 'https://example.com', got '%s'", handler.issuerURL)
|
||||
}
|
||||
|
||||
if len(handler.scopes) != 3 {
|
||||
t.Errorf("Expected 3 scopes, got %d", len(handler.scopes))
|
||||
}
|
||||
|
||||
if handler.overrideScopes {
|
||||
t.Error("overrideScopes should be false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_MaxRedirects tests redirect limit enforcement
|
||||
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
session := &mockSessionData{redirectCount: 5} // At the limit
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "test-nonce", nil }
|
||||
generateCodeVerifier := func() (string, error) { return "", nil }
|
||||
deriveCodeChallenge := func() (string, error) { return "", nil }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
if rw.Code != http.StatusLoopDetected {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusLoopDetected, rw.Code)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, "Too many redirects") {
|
||||
t.Errorf("Expected 'Too many redirects' in response body, got '%s'", body)
|
||||
}
|
||||
|
||||
if session.redirectCount != 0 {
|
||||
t.Errorf("Expected redirect count to be reset, got %d", session.redirectCount)
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_NonceGenerationError tests nonce generation failure
|
||||
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "", &testError{"nonce generation failed"} }
|
||||
generateCodeVerifier := func() (string, error) { return "", nil }
|
||||
deriveCodeChallenge := func() (string, error) { return "", nil }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
if rw.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, "Failed to generate nonce") {
|
||||
t.Errorf("Expected 'Failed to generate nonce' in response body, got '%s'", body)
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError tests PKCE code verifier generation failure
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "test-nonce", nil }
|
||||
generateCodeVerifier := func() (string, error) { return "", &testError{"code verifier generation failed"} }
|
||||
deriveCodeChallenge := func() (string, error) { return "", nil }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
if rw.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, "Failed to generate code verifier") {
|
||||
t.Errorf("Expected 'Failed to generate code verifier' in response body, got '%s'", body)
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError tests PKCE code challenge derivation failure
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "test-nonce", nil }
|
||||
generateCodeVerifier := func() (string, error) { return "test-verifier", nil }
|
||||
deriveCodeChallenge := func() (string, error) { return "", &testError{"code challenge derivation failed"} }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
if rw.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, "Failed to generate code challenge") {
|
||||
t.Errorf("Expected 'Failed to generate code challenge' in response body, got '%s'", body)
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_SessionSaveError tests session save failure
|
||||
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
session := &mockSessionData{saveError: &testError{"save failed"}}
|
||||
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "test-nonce", nil }
|
||||
generateCodeVerifier := func() (string, error) { return "", nil }
|
||||
deriveCodeChallenge := func() (string, error) { return "", nil }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
if rw.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
||||
}
|
||||
|
||||
body := rw.Body.String()
|
||||
if !strings.Contains(body, "Failed to save session") {
|
||||
t.Errorf("Expected 'Failed to save session' in response body, got '%s'", body)
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
|
||||
// Verify session was prepared correctly before the save failure
|
||||
if session.incomingPath != "/test?param=value" {
|
||||
t.Errorf("Expected incoming path '/test?param=value', got '%s'", session.incomingPath)
|
||||
}
|
||||
|
||||
if session.nonce != "test-nonce" {
|
||||
t.Errorf("Expected nonce 'test-nonce', got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.redirectCount != 1 {
|
||||
t.Errorf("Expected redirect count 1, got %d", session.redirectCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_InitiateAuthentication_Success tests successful authentication initiation
|
||||
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
generateNonce := func() (string, error) { return "generated-nonce", nil }
|
||||
generateCodeVerifier := func() (string, error) { return "generated-verifier", nil }
|
||||
deriveCodeChallenge := func() (string, error) { return "generated-challenge", nil }
|
||||
|
||||
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
||||
|
||||
// Should redirect
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected Location header to be set")
|
||||
}
|
||||
|
||||
// Parse the redirect URL to verify parameters
|
||||
parsedURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse redirect URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// Verify required parameters
|
||||
if query.Get("client_id") != "test-client" {
|
||||
t.Errorf("Expected client_id 'test-client', got '%s'", query.Get("client_id"))
|
||||
}
|
||||
|
||||
if query.Get("response_type") != "code" {
|
||||
t.Errorf("Expected response_type 'code', got '%s'", query.Get("response_type"))
|
||||
}
|
||||
|
||||
if query.Get("redirect_uri") != "https://example.com/callback" {
|
||||
t.Errorf("Expected redirect_uri 'https://example.com/callback', got '%s'", query.Get("redirect_uri"))
|
||||
}
|
||||
|
||||
if query.Get("nonce") != "generated-nonce" {
|
||||
t.Errorf("Expected nonce 'generated-nonce', got '%s'", query.Get("nonce"))
|
||||
}
|
||||
|
||||
// Verify PKCE parameters
|
||||
if query.Get("code_challenge") != "generated-challenge" {
|
||||
t.Errorf("Expected code_challenge 'generated-challenge', got '%s'", query.Get("code_challenge"))
|
||||
}
|
||||
|
||||
if query.Get("code_challenge_method") != "S256" {
|
||||
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
|
||||
}
|
||||
|
||||
// Verify scope
|
||||
scope := query.Get("scope")
|
||||
if !strings.Contains(scope, "openid") || !strings.Contains(scope, "email") {
|
||||
t.Errorf("Expected scope to contain 'openid' and 'email', got '%s'", scope)
|
||||
}
|
||||
|
||||
// Verify session was updated correctly
|
||||
if !session.dirty {
|
||||
t.Error("Expected session to be marked dirty")
|
||||
}
|
||||
|
||||
if session.incomingPath != "/protected/resource" {
|
||||
t.Errorf("Expected incoming path '/protected/resource', got '%s'", session.incomingPath)
|
||||
}
|
||||
|
||||
if session.nonce != "generated-nonce" {
|
||||
t.Errorf("Expected session nonce 'generated-nonce', got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.codeVerifier != "generated-verifier" {
|
||||
t.Errorf("Expected session code verifier 'generated-verifier', got '%s'", session.codeVerifier)
|
||||
}
|
||||
|
||||
// Verify session data was cleared
|
||||
if session.authenticated {
|
||||
t.Error("Expected session to not be authenticated")
|
||||
}
|
||||
|
||||
if session.email != "" {
|
||||
t.Errorf("Expected email to be cleared, got '%s'", session.email)
|
||||
}
|
||||
|
||||
if session.accessToken != "" {
|
||||
t.Errorf("Expected access token to be cleared, got '%s'", session.accessToken)
|
||||
}
|
||||
|
||||
if session.idToken != "" {
|
||||
t.Errorf("Expected ID token to be cleared, got '%s'", session.idToken)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_GoogleProvider tests Google-specific URL building
|
||||
func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
||||
[]string{"openid", "profile", "email"}, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// Google-specific parameters
|
||||
if query.Get("access_type") != "offline" {
|
||||
t.Errorf("Expected access_type 'offline' for Google, got '%s'", query.Get("access_type"))
|
||||
}
|
||||
|
||||
if query.Get("prompt") != "consent" {
|
||||
t.Errorf("Expected prompt 'consent' for Google, got '%s'", query.Get("prompt"))
|
||||
}
|
||||
|
||||
// Standard parameters should still be present
|
||||
if query.Get("client_id") != "google-client" {
|
||||
t.Errorf("Expected client_id 'google-client', got '%s'", query.Get("client_id"))
|
||||
}
|
||||
|
||||
if query.Get("state") != "test-state" {
|
||||
t.Errorf("Expected state 'test-state', got '%s'", query.Get("state"))
|
||||
}
|
||||
|
||||
if query.Get("nonce") != "test-nonce" {
|
||||
t.Errorf("Expected nonce 'test-nonce', got '%s'", query.Get("nonce"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_AzureProvider tests Azure-specific URL building
|
||||
func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
[]string{"openid", "profile", "email"}, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// Azure-specific parameters
|
||||
if query.Get("response_mode") != "query" {
|
||||
t.Errorf("Expected response_mode 'query' for Azure, got '%s'", query.Get("response_mode"))
|
||||
}
|
||||
|
||||
// Azure should add offline_access scope automatically
|
||||
scope := query.Get("scope")
|
||||
if !strings.Contains(scope, "offline_access") {
|
||||
t.Errorf("Expected scope to contain 'offline_access' for Azure, got '%s'", scope)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_PKCEEnabled tests PKCE parameter inclusion
|
||||
func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
if query.Get("code_challenge") != "test-challenge" {
|
||||
t.Errorf("Expected code_challenge 'test-challenge', got '%s'", query.Get("code_challenge"))
|
||||
}
|
||||
|
||||
if query.Get("code_challenge_method") != "S256" {
|
||||
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_PKCEDisabled tests when PKCE is disabled
|
||||
func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// PKCE parameters should not be included
|
||||
if query.Get("code_challenge") != "" {
|
||||
t.Errorf("Expected no code_challenge when PKCE disabled, got '%s'", query.Get("code_challenge"))
|
||||
}
|
||||
|
||||
if query.Get("code_challenge_method") != "" {
|
||||
t.Errorf("Expected no code_challenge_method when PKCE disabled, got '%s'", query.Get("code_challenge_method"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_ScopeHandling tests various scope configurations
|
||||
func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
isAzure bool
|
||||
expectedScopes []string
|
||||
}{
|
||||
{
|
||||
name: "Basic scopes",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
overrideScopes: false,
|
||||
isAzure: false,
|
||||
expectedScopes: []string{"openid", "profile", "email", "offline_access"},
|
||||
},
|
||||
{
|
||||
name: "Azure with offline_access already present",
|
||||
scopes: []string{"openid", "profile", "offline_access"},
|
||||
overrideScopes: false,
|
||||
isAzure: true,
|
||||
expectedScopes: []string{"openid", "profile", "offline_access"},
|
||||
},
|
||||
{
|
||||
name: "Azure auto-add offline_access",
|
||||
scopes: []string{"openid", "profile"},
|
||||
overrideScopes: false,
|
||||
isAzure: true,
|
||||
expectedScopes: []string{"openid", "profile", "offline_access"},
|
||||
},
|
||||
{
|
||||
name: "Override scopes with empty array",
|
||||
scopes: []string{},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedScopes: []string{"offline_access"},
|
||||
},
|
||||
{
|
||||
name: "Override scopes prevents auto-add",
|
||||
scopes: []string{"openid", "custom_scope"},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedScopes: []string{"openid", "custom_scope"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
tt.scopes, tt.overrideScopes)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
// Check each expected scope is present
|
||||
for _, expectedScope := range tt.expectedScopes {
|
||||
found := false
|
||||
for _, actualScope := range actualScopes {
|
||||
if actualScope == expectedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected scope '%s' not found in '%s'", expectedScope, actualScope)
|
||||
}
|
||||
}
|
||||
|
||||
// Check no unexpected scopes are present
|
||||
for _, actualScope := range actualScopes {
|
||||
if actualScope == "" {
|
||||
continue // Skip empty strings from split
|
||||
}
|
||||
found := false
|
||||
for _, expectedScope := range tt.expectedScopes {
|
||||
if actualScope == expectedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Unexpected scope '%s' found in '%s'", actualScope, parsedURL.Query().Get("scope"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper type for errors
|
||||
type testError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *testError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
@@ -1,562 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_validateURL tests URL validation functionality
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
url: "http://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
url: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "not-a-url",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - javascript",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - data",
|
||||
url: "data:text/html,<script>alert('xss')</script>",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - file",
|
||||
url: "file:///etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - ftp",
|
||||
url: "ftp://example.com/file",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
url: "https://example.com/../../../etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Path traversal in middle",
|
||||
url: "https://example.com/path/../sensitive/file",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Localhost attempt",
|
||||
url: "https://localhost/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 attempt",
|
||||
url: "https://127.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost attempt",
|
||||
url: "https://[::1]/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 attempt",
|
||||
url: "https://0.0.0.0/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 192.168.x.x",
|
||||
url: "https://192.168.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 10.x.x.x",
|
||||
url: "https://10.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 172.16.x.x",
|
||||
url: "https://172.16.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
url: "https://169.254.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
url: "https://224.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Valid public IP",
|
||||
url: "https://8.8.8.8/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid domain with port",
|
||||
url: "https://example.com:8443/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with case variation",
|
||||
url: "https://LOCALHOST/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
url: "https://example.com:notanumber/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid URL format",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateURL(tt.url)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateURL() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost tests host validation specifically
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid hostname",
|
||||
host: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with subdomain",
|
||||
host: "api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with port",
|
||||
host: "example.com:8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty host",
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "localhost",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (case insensitive)",
|
||||
host: "LOCALHOST",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
host: "localhost:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "127.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
host: "127.0.0.1:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
host: "::1",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0",
|
||||
host: "0.0.0.0",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.1.1",
|
||||
host: "192.168.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.0.0.1",
|
||||
host: "10.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 172.16.0.1",
|
||||
host: "172.16.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Public IP 8.8.8.8",
|
||||
host: "8.8.8.8",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
host: "169.254.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
host: "224.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
host: "example.com::",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "Valid international domain",
|
||||
host: "example.org",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid ccTLD",
|
||||
host: "example.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateHost(tt.host)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateHost() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateHost() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams tests URL building with parameters
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
params url.Values
|
||||
expected string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Absolute HTTPS URL",
|
||||
baseURL: "https://provider.com/auth",
|
||||
params: url.Values{
|
||||
"client_id": []string{"test-client"},
|
||||
"response_type": []string{"code"},
|
||||
},
|
||||
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTP URL",
|
||||
baseURL: "http://provider.com/auth",
|
||||
params: url.Values{
|
||||
"state": []string{"test-state"},
|
||||
},
|
||||
expected: "http://provider.com/auth?state=test-state",
|
||||
},
|
||||
{
|
||||
name: "Relative URL resolved against issuer",
|
||||
baseURL: "/oauth2/authorize",
|
||||
params: url.Values{
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
expected: "https://example.com/oauth2/authorize?scope=openid",
|
||||
},
|
||||
{
|
||||
name: "Root relative URL",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{
|
||||
"nonce": []string{"test-nonce"},
|
||||
},
|
||||
expected: "https://example.com/auth?nonce=test-nonce",
|
||||
},
|
||||
{
|
||||
name: "Invalid absolute URL",
|
||||
baseURL: "https://localhost/auth",
|
||||
params: url.Values{},
|
||||
expectEmpty: true, // Should return empty string due to validation failure
|
||||
},
|
||||
{
|
||||
name: "Invalid relative URL when resolved",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{},
|
||||
expected: "", // Should be empty because issuer validation would be tested separately
|
||||
expectEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.buildURLWithParams(tt.baseURL, tt.params)
|
||||
|
||||
if tt.expectEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For relative URLs, we expect them to be resolved against the issuer URL
|
||||
if !strings.HasPrefix(tt.baseURL, "http") {
|
||||
// Verify it starts with the issuer URL
|
||||
if !strings.HasPrefix(result, handler.issuerURL) {
|
||||
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the result to verify parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all expected parameters are present
|
||||
resultParams := parsedURL.Query()
|
||||
for key, expectedValues := range tt.params {
|
||||
actualValues := resultParams[key]
|
||||
if len(actualValues) != len(expectedValues) {
|
||||
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
|
||||
continue
|
||||
}
|
||||
for i, expectedValue := range expectedValues {
|
||||
if actualValues[i] != expectedValue {
|
||||
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
|
||||
"state": []string{"state with spaces and & special chars"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"special": []string{"value+with+plus&ersand=equals"},
|
||||
}
|
||||
|
||||
result := handler.buildURLWithParams("https://provider.com/auth", params)
|
||||
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse result URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify parameters are correctly encoded/decoded
|
||||
resultParams := parsedURL.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"redirect_uri": "https://example.com/callback?test=value&other=data",
|
||||
"state": "state with spaces and & special chars",
|
||||
"scope": "openid profile email",
|
||||
"special": "value+with+plus&ersand=equals",
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedParams {
|
||||
actualValue := resultParams.Get(key)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateParsedURL tests validateParsedURL method
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL with warning",
|
||||
url: "http://example.com/path",
|
||||
wantErr: false, // Should not error but should log warning
|
||||
},
|
||||
{
|
||||
name: "Invalid scheme",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal",
|
||||
url: "https://example.com/path/../../../etc",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Invalid host (private IP)",
|
||||
url: "https://192.168.1.1/path",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(tt.url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse test URL: %v", err)
|
||||
}
|
||||
|
||||
err = handler.validateParsedURL(parsedURL)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateParsedURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateParsedURL() unexpected error = %v", err)
|
||||
}
|
||||
|
||||
// Check for HTTP warning in debug logs
|
||||
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if strings.Contains(msg, "Warning: Using HTTP scheme") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected HTTP scheme warning in debug logs")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+21
-15
@@ -8,10 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// AUTHENTICATION FLOW
|
||||
// ============================================================================
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
|
||||
const maxRedirects = 5
|
||||
@@ -47,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
// prepareSessionForAuthentication clears existing session data and sets new authentication state
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
session.SetAuthenticated(false)
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
@@ -223,15 +219,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +246,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -276,7 +282,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
// - redirectURL: The callback URL to be used in the new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
||||
session.SetAuthenticated(false)
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,101 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeneratePKCEParameters tests the generatePKCEParameters method
|
||||
func TestGeneratePKCEParameters(t *testing.T) {
|
||||
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE enabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
|
||||
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
|
||||
|
||||
// Verify the challenge is derived from the verifier
|
||||
expectedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE disabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: false,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
|
||||
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err1)
|
||||
|
||||
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
|
||||
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The challenge should always be derivable from the verifier
|
||||
recalculatedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, challenge, recalculatedChallenge,
|
||||
"challenge should always match the SHA256 hash of verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, _, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// RFC 7636 requires verifier to be 43-128 characters
|
||||
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
|
||||
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
_, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// SHA256 hash base64 encoded should be 43 characters
|
||||
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
|
||||
})
|
||||
}
|
||||
+19
-16
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
|
||||
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
|
||||
// It limits concurrent task execution and tracks failures to prevent system overload
|
||||
type TaskCircuitBreaker struct {
|
||||
state int32 // CircuitBreakerState
|
||||
failureCount int32
|
||||
lastFailureTime int64 // Unix timestamp
|
||||
failureThreshold int32
|
||||
timeout time.Duration
|
||||
logger *Logger
|
||||
// Concurrency limiting
|
||||
concurrentTasks int32 // Current number of running tasks
|
||||
maxConcurrent int32 // Maximum concurrent tasks allowed
|
||||
activeTasks map[string]struct{} // Track active task names
|
||||
tasksMu sync.RWMutex // Separate mutex for task tracking
|
||||
activeTasks map[string]struct{}
|
||||
lastFailureTime int64
|
||||
timeout time.Duration
|
||||
tasksMu sync.RWMutex
|
||||
state int32
|
||||
failureCount int32
|
||||
failureThreshold int32
|
||||
concurrentTasks int32
|
||||
maxConcurrent int32
|
||||
}
|
||||
|
||||
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
|
||||
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
|
||||
max := atomic.LoadInt32(&cb.maxConcurrent)
|
||||
|
||||
// For cleanup tasks, be more restrictive (singleton-like behavior)
|
||||
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
|
||||
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
|
||||
cb.tasksMu.RLock()
|
||||
hasCleanupTask := false
|
||||
hasSameTask := false
|
||||
for activeTask := range cb.activeTasks {
|
||||
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
|
||||
hasCleanupTask = true
|
||||
// Only block if the EXACT same task is already running
|
||||
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
|
||||
if activeTask == taskName {
|
||||
hasSameTask = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cb.tasksMu.RUnlock()
|
||||
|
||||
if hasCleanupTask {
|
||||
if hasSameTask {
|
||||
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
|
||||
}
|
||||
}
|
||||
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
|
||||
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
cb *TaskCircuitBreaker
|
||||
logger *Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// GlobalTaskRegistry is the singleton instance for managing all background tasks
|
||||
@@ -538,7 +540,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
|
||||
|
||||
// Start the task if not already running
|
||||
if !rm.IsTaskRunning(name) {
|
||||
rm.StartBackgroundTask(name)
|
||||
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
|
||||
}
|
||||
|
||||
// Get the task from resource manager's internal registry
|
||||
@@ -787,6 +789,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
+8
-7
@@ -58,6 +58,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60 * time.Second,
|
||||
@@ -78,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
|
||||
|
||||
// Initialize session manager
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
// Mock the JWT verification to avoid JWKS lookup issues
|
||||
@@ -329,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidGoogleTokens",
|
||||
@@ -475,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
providerType string
|
||||
setupSession func() *SessionData
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "AzureProvider",
|
||||
@@ -659,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UnauthenticatedWithRefreshToken",
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMemoryMonitorComprehensive tests memory monitor edge cases
|
||||
func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.TriggerGC()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Initially should return None (no stats yet)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Equal(t, MemoryPressureNone, pressure)
|
||||
|
||||
// Collect stats to populate lastStats
|
||||
monitor.GetCurrentStats()
|
||||
|
||||
// Now should return a valid pressure level
|
||||
pressure = monitor.GetMemoryPressure()
|
||||
assert.NotNil(t, pressure)
|
||||
})
|
||||
|
||||
t.Run("StartMonitoring can be called", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Start monitoring should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 100*time.Millisecond)
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
})
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// StopMonitoring should not panic even if not started
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
// Can be called multiple times safely
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
// Get initial instance
|
||||
GetGlobalMemoryMonitor()
|
||||
|
||||
// Reset
|
||||
ResetGlobalMemoryMonitor()
|
||||
|
||||
// Should be able to get a new instance
|
||||
monitor := GetGlobalMemoryMonitor()
|
||||
assert.NotNil(t, monitor)
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
ResetGlobalMemoryMonitor()
|
||||
})
|
||||
|
||||
t.Run("String method returns pressure name", func(t *testing.T) {
|
||||
pressures := []struct {
|
||||
name string
|
||||
level MemoryPressureLevel
|
||||
}{
|
||||
{level: MemoryPressureNone, name: "None"},
|
||||
{level: MemoryPressureLow, name: "Low"},
|
||||
{level: MemoryPressureModerate, name: "Moderate"},
|
||||
{level: MemoryPressureHigh, name: "High"},
|
||||
{level: MemoryPressureCritical, name: "Critical"},
|
||||
{level: MemoryPressureLevel(999), name: "Unknown"},
|
||||
}
|
||||
|
||||
for _, p := range pressures {
|
||||
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.NotZero(t, stats.Timestamp)
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskRegistry tests background task registry edge cases
|
||||
func TestBackgroundTaskRegistry(t *testing.T) {
|
||||
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
|
||||
registry1 := GetGlobalTaskRegistry()
|
||||
registry2 := GetGlobalTaskRegistry()
|
||||
|
||||
assert.Equal(t, registry1, registry2, "should return same instance")
|
||||
})
|
||||
|
||||
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-register-task"
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask(taskName, task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify task was registered
|
||||
_, exists := registry.GetTask(taskName)
|
||||
assert.True(t, exists, "task should be registered")
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-singleton-idempotent"
|
||||
callCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
taskFunc := func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// First creation should succeed
|
||||
task1, err1 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NotNil(t, task1)
|
||||
|
||||
// Second creation should also succeed (idempotent)
|
||||
// Returns same task without error
|
||||
task2, err2 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
|
||||
assert.NotNil(t, task2)
|
||||
|
||||
// Clean up
|
||||
if task1 != nil {
|
||||
task1.Stop()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Initially should be 0 or small number
|
||||
initialCount := registry.GetTaskCount()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"count-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask("count-test-task", task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Count should increase
|
||||
newCount := registry.GetTaskCount()
|
||||
assert.Equal(t, initialCount+1, newCount)
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Create multiple tasks
|
||||
for i := 0; i < 3; i++ {
|
||||
taskName := "multi-task-" + string(rune(i+'0'))
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask(taskName, task)
|
||||
}
|
||||
|
||||
// Verify tasks were created
|
||||
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
|
||||
|
||||
// Stop all tasks
|
||||
registry.StopAllTasks()
|
||||
|
||||
// Verify all tasks are removed
|
||||
taskCount := registry.GetTaskCount()
|
||||
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"reset-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask("reset-test-task", task)
|
||||
|
||||
// Reset
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
// Get new registry
|
||||
newRegistry := GetGlobalTaskRegistry()
|
||||
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskLifecycle tests background task lifecycle
|
||||
func TestBackgroundTaskLifecycle(t *testing.T) {
|
||||
t.Run("Start begins task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
executed := false
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"lifecycle-test",
|
||||
50*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
executed = true
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Wait for execution
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Verify it executed
|
||||
mu.Lock()
|
||||
wasExecuted := executed
|
||||
mu.Unlock()
|
||||
|
||||
assert.True(t, wasExecuted, "task should have executed")
|
||||
})
|
||||
|
||||
t.Run("Stop halts task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"stop-test",
|
||||
30*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Let it run a few times
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Record count
|
||||
mu.Lock()
|
||||
countAfterStop := execCount
|
||||
mu.Unlock()
|
||||
|
||||
// Wait more
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Count should not increase
|
||||
mu.Lock()
|
||||
finalCount := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
|
||||
})
|
||||
|
||||
t.Run("Multiple Start calls are safe", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-start-test",
|
||||
100*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Multiple starts should be safe
|
||||
task.Start()
|
||||
task.Start()
|
||||
task.Start()
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Should have executed, but only one goroutine
|
||||
mu.Lock()
|
||||
count := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
|
||||
})
|
||||
|
||||
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-stop-test",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start and stop
|
||||
task.Start()
|
||||
time.Sleep(GetTestDuration(20 * time.Millisecond))
|
||||
|
||||
// Multiple stops should be safe
|
||||
assert.NotPanics(t, func() {
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryMonitorIntegration tests memory monitor integration
|
||||
func TestMemoryMonitorIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory monitor integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("monitoring updates stats", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
defer monitor.StopMonitoring()
|
||||
|
||||
// Start monitoring
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 50*time.Millisecond)
|
||||
|
||||
// Wait for at least one check
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
|
||||
// Get pressure (should be a valid pressure level)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, pressure, "pressure should be a valid level")
|
||||
|
||||
// Stop monitoring
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("global memory monitor singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
monitor1 := GetGlobalMemoryMonitor()
|
||||
monitor2 := GetGlobalMemoryMonitor()
|
||||
|
||||
assert.Equal(t, monitor1, monitor2, "should return same instance")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryStatsCollection tests memory statistics collection
|
||||
func TestMemoryStatsCollection(t *testing.T) {
|
||||
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.HeapSysBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.False(t, stats.Timestamp.IsZero())
|
||||
})
|
||||
|
||||
t.Run("Stats include memory pressure", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
// Should calculate and include pressure level
|
||||
assert.NotNil(t, stats.MemoryPressure)
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, stats.MemoryPressure)
|
||||
})
|
||||
|
||||
t.Run("TriggerGC reduces memory", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Allocate some memory
|
||||
_ = make([]byte, 1024*1024) // 1MB
|
||||
|
||||
// Get stats before GC
|
||||
beforeStats := monitor.GetCurrentStats()
|
||||
|
||||
// Trigger GC
|
||||
monitor.TriggerGC()
|
||||
|
||||
// Get stats after GC
|
||||
afterStats := monitor.GetCurrentStats()
|
||||
|
||||
// After GC should have different stats
|
||||
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// UNIVERSAL CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheSet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Get(fmt.Sprintf("key%d", i%1000))
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheSetGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheLRUEviction(b *testing.B) {
|
||||
config := createTestCacheConfig()
|
||||
config.MaxSize = 100
|
||||
cache := NewUniversalCache(config)
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheConcurrent(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
case 1:
|
||||
cache.Get(fmt.Sprintf("key%d", i))
|
||||
case 2:
|
||||
cache.Delete(fmt.Sprintf("key%d", i))
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE MANAGER BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE COMPATIBILITY BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SHARDED CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkShardedCache(b *testing.B) {
|
||||
b.Run("Set", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Get", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
for i := 0; i < 10000; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get(fmt.Sprintf("key-%d", i%10000))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ParallelSetGet", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
cache.Set(key, i, 5*time.Minute)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
|
||||
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
|
||||
b.Run("ShardedCache64", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
if !cache.Exists(key) {
|
||||
cache.Set(key, true, 5*time.Minute)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("GlobalMutexCache", func(b *testing.B) {
|
||||
var mu sync.RWMutex
|
||||
data := make(map[string]bool)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
|
||||
mu.RLock()
|
||||
_, exists := data[key]
|
||||
mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
mu.Lock()
|
||||
data[key] = true
|
||||
mu.Unlock()
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
+3
-3
@@ -155,9 +155,9 @@ type CacheStrategy interface {
|
||||
|
||||
// CacheEntry for backward compatibility
|
||||
type CacheEntry struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
Value interface{}
|
||||
Key string
|
||||
}
|
||||
|
||||
// Cache is an alias for backward compatibility
|
||||
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
|
||||
|
||||
// ListNode for backward compatibility
|
||||
type ListNode struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
Next *ListNode
|
||||
Prev *ListNode
|
||||
Key string
|
||||
}
|
||||
|
||||
// NewFixedMetadataCache creates a metadata cache with fixed configuration
|
||||
|
||||
@@ -1,369 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewBoundedCache tests creation of bounded cache
|
||||
func TestNewBoundedCache(t *testing.T) {
|
||||
maxSize := 500
|
||||
cache := NewBoundedCache(maxSize)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify we can use basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultUnifiedCacheConfig tests default configuration
|
||||
func TestDefaultUnifiedCacheConfig(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
|
||||
if config.Type != CacheTypeGeneral {
|
||||
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
|
||||
}
|
||||
|
||||
if config.MaxSize != 500 {
|
||||
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
|
||||
}
|
||||
|
||||
if config.MaxMemoryBytes != 64*1024*1024 {
|
||||
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
|
||||
}
|
||||
|
||||
if config.CleanupInterval != 2*time.Minute {
|
||||
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
t.Error("Expected Logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewUnifiedCache tests unified cache creation
|
||||
func TestNewUnifiedCache(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
if cache.UniversalCache == nil {
|
||||
t.Error("Expected UniversalCache to be set")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
|
||||
func TestUnifiedCache_SetMaxSize(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
// Test setting max size
|
||||
newSize := 1000
|
||||
cache.SetMaxSize(newSize)
|
||||
|
||||
// We can't easily verify the size was set without exposing internal fields,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestNewCacheAdapter tests cache adapter creation
|
||||
func TestNewCacheAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache interface{}
|
||||
expectNil bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UniversalCache",
|
||||
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UniversalCache",
|
||||
},
|
||||
{
|
||||
name: "UnifiedCache",
|
||||
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UnifiedCache",
|
||||
},
|
||||
{
|
||||
name: "Invalid cache type",
|
||||
cache: "not-a-cache",
|
||||
expectNil: true,
|
||||
description: "Should return nil for invalid cache type",
|
||||
},
|
||||
{
|
||||
name: "Nil cache",
|
||||
cache: nil,
|
||||
expectNil: true,
|
||||
description: "Should return nil for nil cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adapter := NewCacheAdapter(tt.cache)
|
||||
|
||||
if tt.expectNil {
|
||||
if adapter != nil {
|
||||
t.Errorf("Expected nil adapter, got %v", adapter)
|
||||
}
|
||||
} else {
|
||||
if adapter == nil {
|
||||
t.Error("Expected non-nil adapter")
|
||||
}
|
||||
// Test basic operations
|
||||
adapter.Set("test", "value", time.Hour)
|
||||
value, found := adapter.Get("test")
|
||||
if !found {
|
||||
t.Error("Expected key to be found")
|
||||
}
|
||||
if value != "value" {
|
||||
t.Errorf("Expected 'value', got %v", value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCache tests optimized cache creation
|
||||
func TestNewOptimizedCache(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLRUStrategy tests LRU strategy creation
|
||||
func TestNewLRUStrategy(t *testing.T) {
|
||||
maxSize := 100
|
||||
strategy := NewLRUStrategy(maxSize)
|
||||
|
||||
if strategy == nil {
|
||||
t.Fatal("Expected strategy to be created, got nil")
|
||||
}
|
||||
|
||||
lruStrategy, ok := strategy.(*LRUStrategy)
|
||||
if !ok {
|
||||
t.Fatal("Expected LRUStrategy type")
|
||||
}
|
||||
|
||||
if lruStrategy.maxSize != maxSize {
|
||||
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
|
||||
}
|
||||
|
||||
if lruStrategy.order == nil {
|
||||
t.Error("Expected order list to be initialized")
|
||||
}
|
||||
|
||||
if lruStrategy.elements == nil {
|
||||
t.Error("Expected elements map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_Name tests strategy name
|
||||
func TestLRUStrategy_Name(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
name := strategy.Name()
|
||||
if name != "LRU" {
|
||||
t.Errorf("Expected 'LRU', got %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_ShouldEvict tests eviction logic
|
||||
func TestLRUStrategy_ShouldEvict(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// LRU strategy always returns false for ShouldEvict
|
||||
result := strategy.ShouldEvict("test-item", time.Now())
|
||||
if result != false {
|
||||
t.Error("Expected ShouldEvict to return false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnAccess tests access callback
|
||||
func TestLRUStrategy_OnAccess(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnAccess should not panic
|
||||
strategy.OnAccess("test-key", "test-value")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnRemove tests removal callback
|
||||
func TestLRUStrategy_OnRemove(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnRemove should not panic
|
||||
strategy.OnRemove("test-key")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_EstimateSize tests size estimation
|
||||
func TestLRUStrategy_EstimateSize(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
size := strategy.EstimateSize("test-item")
|
||||
if size != 64 {
|
||||
t.Errorf("Expected size 64, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
|
||||
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
key, found := strategy.GetEvictionCandidate()
|
||||
if found {
|
||||
t.Error("Expected no eviction candidate to be found")
|
||||
}
|
||||
if key != "" {
|
||||
t.Errorf("Expected empty key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
|
||||
func TestNewOptimizedCacheWithConfig(t *testing.T) {
|
||||
config := UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 128 * 1024 * 1024,
|
||||
EnableMetrics: true,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
|
||||
cache := NewOptimizedCacheWithConfig(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewFixedMetadataCache tests fixed metadata cache creation
|
||||
func TestNewFixedMetadataCache(t *testing.T) {
|
||||
cache := NewFixedMetadataCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with proper metadata operations
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/auth",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
}
|
||||
|
||||
err := cache.Set("test-provider", metadata, time.Hour)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error setting metadata: %v", err)
|
||||
}
|
||||
|
||||
// Test that the cache was created (basic verification)
|
||||
// Note: We can't easily test Get without more complex setup
|
||||
}
|
||||
|
||||
// TestNewDoublyLinkedList tests doubly linked list creation
|
||||
func TestNewDoublyLinkedList(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
if list == nil {
|
||||
t.Fatal("Expected list to be created, got nil")
|
||||
}
|
||||
|
||||
// Test it's a proper list structure
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected empty list initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoublyLinkedList_PopFront tests front element removal
|
||||
func TestDoublyLinkedList_PopFront(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
// Test popping from empty list
|
||||
element := list.PopFront()
|
||||
if element != nil {
|
||||
t.Error("Expected nil when popping from empty list")
|
||||
}
|
||||
|
||||
// Add an element and test popping
|
||||
added := list.PushBack("test-value")
|
||||
if added == nil {
|
||||
t.Fatal("Expected element to be added")
|
||||
}
|
||||
|
||||
popped := list.PopFront()
|
||||
if popped == nil {
|
||||
t.Error("Expected element to be popped")
|
||||
}
|
||||
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected list to be empty after popping")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+57
-7
@@ -21,10 +21,37 @@ var (
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
if config != nil {
|
||||
logger = NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
redisConfig = config.Redis
|
||||
}
|
||||
|
||||
globalCacheManagerInstance = &CacheManager{
|
||||
manager: GetUniversalCacheManager(nil),
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
return globalCacheManagerInstance
|
||||
@@ -34,7 +61,7 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenCache returns the shared token cache
|
||||
@@ -61,6 +88,22 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
return &JWKCache{cache: cm.manager.GetJWKCache()}
|
||||
}
|
||||
|
||||
// GetSharedIntrospectionCache returns the shared token introspection cache
|
||||
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
// for caching token type detection results to improve performance
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
@@ -78,12 +121,13 @@ func CleanupGlobalCacheManager() error {
|
||||
|
||||
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
|
||||
type CacheInterfaceWrapper struct {
|
||||
cache *UniversalCache
|
||||
cache *UniversalCache
|
||||
managed bool // If true, cache is managed globally and Close() is a no-op
|
||||
}
|
||||
|
||||
// Set stores a value
|
||||
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
|
||||
c.cache.Set(key, value, ttl)
|
||||
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves a value
|
||||
@@ -106,11 +150,17 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
c.cache.Cleanup()
|
||||
}
|
||||
|
||||
// Close shuts down the cache
|
||||
// Close shuts down the cache if it's not managed globally.
|
||||
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
|
||||
// when multiple plugin instances are closed during Traefik configuration reloads.
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.managed {
|
||||
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
|
||||
return
|
||||
}
|
||||
// Standalone cache - close it properly to stop cleanup goroutines
|
||||
if c.cache != nil {
|
||||
c.cache.Close()
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to ensure we have a working cache manager for tests
|
||||
func getTestCacheManager(t *testing.T) *CacheManager {
|
||||
cm := GetGlobalCacheManager(&sync.WaitGroup{})
|
||||
if cm == nil {
|
||||
t.Fatal("Failed to get cache manager")
|
||||
}
|
||||
if cm.manager == nil {
|
||||
t.Fatal("Cache manager has nil internal manager")
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
// TestCacheManager_Close tests cache manager close functionality
|
||||
func TestCacheManager_Close(t *testing.T) {
|
||||
// Get a fresh cache manager
|
||||
wg := &sync.WaitGroup{}
|
||||
cm := GetGlobalCacheManager(wg)
|
||||
|
||||
if cm == nil {
|
||||
t.Fatal("Expected cache manager to be created")
|
||||
}
|
||||
|
||||
// Test closing the cache manager
|
||||
err := cm.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error closing cache manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupGlobalCacheManager tests global cleanup
|
||||
func TestCleanupGlobalCacheManager(t *testing.T) {
|
||||
// Test cleanup when no instance exists (should not error)
|
||||
originalInstance := globalCacheManagerInstance
|
||||
globalCacheManagerInstance = nil
|
||||
err := CleanupGlobalCacheManager()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
|
||||
}
|
||||
|
||||
// Restore original instance
|
||||
globalCacheManagerInstance = originalInstance
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Delete tests delete functionality
|
||||
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Verify it exists
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Fatal("Expected key to be found after setting")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
cache.Delete("test-key")
|
||||
|
||||
// Verify it's gone
|
||||
_, found = cache.Get("test-key")
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Size tests size functionality
|
||||
func TestCacheInterfaceWrapper_Size(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Clear cache first
|
||||
cache.Clear()
|
||||
|
||||
// Check initial size
|
||||
initialSize := cache.Size()
|
||||
if initialSize != 0 {
|
||||
t.Errorf("Expected initial size 0, got %d", initialSize)
|
||||
}
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Check size increased
|
||||
newSize := cache.Size()
|
||||
if newSize != 2 {
|
||||
t.Errorf("Expected size 2, got %d", newSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Clear tests clear functionality
|
||||
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Verify items exist
|
||||
size := cache.Size()
|
||||
if size != 2 {
|
||||
t.Errorf("Expected 2 items before clear, got %d", size)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
cache.Clear()
|
||||
|
||||
// Verify cache is empty
|
||||
size = cache.Size()
|
||||
if size != 0 {
|
||||
t.Errorf("Expected 0 items after clear, got %d", size)
|
||||
}
|
||||
|
||||
// Verify specific items are gone
|
||||
_, found := cache.Get("key1")
|
||||
if found {
|
||||
t.Error("Expected key1 to be cleared")
|
||||
}
|
||||
|
||||
_, found = cache.Get("key2")
|
||||
if found {
|
||||
t.Error("Expected key2 to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
|
||||
func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test close - should not panic
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
wrapper.Close() // Should not panic
|
||||
|
||||
// Test close with nil cache
|
||||
nilWrapper := &CacheInterfaceWrapper{cache: nil}
|
||||
nilWrapper.Close() // Should not panic
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_GetStats tests stats functionality
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := wrapper.GetStats()
|
||||
if stats == nil {
|
||||
t.Error("Expected non-nil stats")
|
||||
}
|
||||
|
||||
// Stats should be accessible (len() never returns negative values)
|
||||
// Just verify it's accessible by checking it's not nil (already done above)
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
|
||||
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item that will expire quickly
|
||||
cache.Set("expire-key", "expire-value", time.Millisecond)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Trigger cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Item should be cleaned up
|
||||
_, found := cache.Get("expire-key")
|
||||
if found {
|
||||
t.Error("Expected expired key to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
|
||||
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test setting max size (should not panic)
|
||||
cache.SetMaxSize(1000)
|
||||
|
||||
// We can't easily verify the size was set without exposing internals,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestGetSharedCaches tests getting shared cache instances
|
||||
func TestGetSharedCaches(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Test getting shared token blacklist
|
||||
blacklist := cm.GetSharedTokenBlacklist()
|
||||
if blacklist == nil {
|
||||
t.Error("Expected non-nil token blacklist")
|
||||
}
|
||||
|
||||
// Test getting shared token cache
|
||||
tokenCache := cm.GetSharedTokenCache()
|
||||
if tokenCache == nil {
|
||||
t.Error("Expected non-nil token cache")
|
||||
}
|
||||
|
||||
// Test getting shared metadata cache
|
||||
metadataCache := cm.GetSharedMetadataCache()
|
||||
if metadataCache == nil {
|
||||
t.Error("Expected non-nil metadata cache")
|
||||
}
|
||||
|
||||
// Test getting shared JWK cache
|
||||
jwkCache := cm.GetSharedJWKCache()
|
||||
if jwkCache == nil {
|
||||
t.Error("Expected non-nil JWK cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentCacheAccess tests thread safety
|
||||
func TestConcurrentCacheAccess(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 10
|
||||
|
||||
// Concurrent operations
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", id, j)
|
||||
value := fmt.Sprintf("value-%d-%d", id, j)
|
||||
|
||||
cache.Set(key, value, time.Hour)
|
||||
|
||||
retrieved, found := cache.Get(key)
|
||||
if found && retrieved != value {
|
||||
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
|
||||
}
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Pre-populate cache
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
+1854
File diff suppressed because it is too large
Load Diff
@@ -1,319 +0,0 @@
|
||||
// Package circuit_breaker provides circuit breaker implementation for resilience
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the current state of a circuit breaker.
|
||||
// The circuit breaker pattern prevents cascading failures by monitoring
|
||||
// error rates and temporarily blocking requests to failing services.
|
||||
type CircuitBreakerState int
|
||||
|
||||
// Circuit breaker states following the standard pattern:
|
||||
// Closed: Normal operation, requests flow through
|
||||
// Open: Circuit is tripped, requests are blocked
|
||||
// HalfOpen: Testing state, limited requests allowed to test recovery
|
||||
const (
|
||||
// CircuitBreakerClosed allows all requests through (normal operation)
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen blocks all requests (service is failing)
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen allows limited requests to test service recovery
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
// String returns a string representation of the circuit breaker state
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Infof(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism interface for common functionality
|
||||
type BaseRecoveryMechanism interface {
|
||||
RecordRequest()
|
||||
RecordSuccess()
|
||||
RecordFailure()
|
||||
GetBaseMetrics() map[string]interface{}
|
||||
LogInfo(format string, args ...interface{})
|
||||
LogError(format string, args ...interface{})
|
||||
LogDebug(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for external service calls.
|
||||
// It monitors failure rates and automatically opens the circuit when failures
|
||||
// exceed the threshold, preventing further requests until the service recovers.
|
||||
type CircuitBreaker struct {
|
||||
// baseRecovery provides common functionality
|
||||
baseRecovery BaseRecoveryMechanism
|
||||
// maxFailures is the threshold for opening the circuit
|
||||
maxFailures int
|
||||
// timeout is how long to wait before allowing requests in half-open state
|
||||
timeout time.Duration
|
||||
// resetTimeout is how long to wait before transitioning from open to half-open
|
||||
resetTimeout time.Duration
|
||||
// state tracks the current circuit breaker state
|
||||
state CircuitBreakerState
|
||||
// failures counts consecutive failures
|
||||
failures int64
|
||||
// lastFailureTime records when the last failure occurred
|
||||
lastFailureTime time.Time
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// logger for debugging and monitoring
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
|
||||
// These settings control when the circuit opens and how it recovers.
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of failures before opening the circuit
|
||||
MaxFailures int `json:"max_failures"`
|
||||
// Timeout is how long to wait before trying to recover (open -> half-open)
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
// ResetTimeout is how long to wait before fully closing the circuit
|
||||
ResetTimeout time.Duration `json:"reset_timeout"`
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
|
||||
// Configured for typical web service scenarios with moderate tolerance for failures.
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
|
||||
// The circuit breaker starts in the closed state, allowing all requests through.
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
baseRecovery: baseRecovery,
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a function through the circuit breaker with context.
|
||||
// It checks if requests are allowed, executes the function, and updates the circuit state
|
||||
// based on the result. Implements the ErrorRecoveryMechanism interface.
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordRequest()
|
||||
}
|
||||
|
||||
if !cb.allowRequest() {
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordSuccess()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute executes a function through the circuit breaker without context.
|
||||
// This is provided for backward compatibility with existing code.
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest determines whether to allow a request based on the circuit state.
|
||||
// Handles state transitions from open to half-open based on timeout.
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
return true
|
||||
|
||||
case CircuitBreakerOpen:
|
||||
if now.Sub(cb.lastFailureTime) > cb.timeout {
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
if cb.logger != nil {
|
||||
cb.logger.Infof("Circuit breaker transitioning to half-open state")
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit.
|
||||
// Updates failure count and triggers state transitions when thresholds are exceeded.
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failures++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
if cb.failures >= int64(cb.maxFailures) {
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
|
||||
}
|
||||
}
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request and potentially closes the circuit.
|
||||
// Resets failure count and transitions from half-open to closed state on success.
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.failures = 0
|
||||
cb.state = CircuitBreakerClosed
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
|
||||
}
|
||||
|
||||
case CircuitBreakerClosed:
|
||||
cb.failures = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker.
|
||||
// Thread-safe method for monitoring circuit breaker status.
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to its initial closed state.
|
||||
// Clears failure count and state, effectively recovering from any open state.
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.state = CircuitBreakerClosed
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the circuit breaker is currently allowing requests.
|
||||
// This provides a quick way to check if the service is available.
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
return cb.allowRequest()
|
||||
}
|
||||
|
||||
// GetMetrics returns comprehensive metrics about the circuit breaker.
|
||||
// Includes state information, failure counts, configuration, and base metrics.
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
lastFailureTime := cb.lastFailureTime
|
||||
cb.mutex.RUnlock()
|
||||
|
||||
var metrics map[string]interface{}
|
||||
if cb.baseRecovery != nil {
|
||||
metrics = cb.baseRecovery.GetBaseMetrics()
|
||||
} else {
|
||||
metrics = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metrics["state"] = state.String()
|
||||
metrics["current_failures"] = failures
|
||||
metrics["max_failures"] = cb.maxFailures
|
||||
metrics["timeout"] = cb.timeout.String()
|
||||
metrics["reset_timeout"] = cb.resetTimeout.String()
|
||||
|
||||
if !lastFailureTime.IsZero() {
|
||||
metrics["last_failure_time"] = lastFailureTime
|
||||
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// GetFailureCount returns the current failure count
|
||||
func (cb *CircuitBreaker) GetFailureCount() int64 {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.failures
|
||||
}
|
||||
|
||||
// GetLastFailureTime returns the time of the last failure
|
||||
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.lastFailureTime
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is in open state
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// IsClosed returns true if the circuit breaker is in closed state
|
||||
func (cb *CircuitBreaker) IsClosed() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerClosed
|
||||
}
|
||||
|
||||
// IsHalfOpen returns true if the circuit breaker is in half-open state
|
||||
func (cb *CircuitBreaker) IsHalfOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerHalfOpen
|
||||
}
|
||||
@@ -1,981 +0,0 @@
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockLogger struct {
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future error log verification tests
|
||||
func (m *mockLogger) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future test isolation
|
||||
func (m *mockLogger) reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = nil
|
||||
m.errorLogs = nil
|
||||
m.debugLogs = nil
|
||||
}
|
||||
|
||||
type mockBaseRecoveryMechanism struct {
|
||||
requestCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
baseMetrics map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
|
||||
return &mockBaseRecoveryMechanism{
|
||||
baseMetrics: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&m.requestCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&m.successCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&m.failureCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range m.baseMetrics {
|
||||
result[k] = v
|
||||
}
|
||||
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
|
||||
result["total_successes"] = atomic.LoadInt64(&m.successCount)
|
||||
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
|
||||
return atomic.LoadInt64(&m.requestCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
|
||||
return atomic.LoadInt64(&m.successCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
|
||||
return atomic.LoadInt64(&m.failureCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func TestCircuitBreakerState_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
state CircuitBreakerState
|
||||
expected string
|
||||
}{
|
||||
{CircuitBreakerClosed, "closed"},
|
||||
{CircuitBreakerOpen, "open"},
|
||||
{CircuitBreakerHalfOpen, "half-open"},
|
||||
{CircuitBreakerState(999), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := tt.state.String()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCircuitBreaker(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
if cb == nil {
|
||||
t.Fatal("NewCircuitBreaker returned nil")
|
||||
}
|
||||
|
||||
if cb.maxFailures != 3 {
|
||||
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
|
||||
}
|
||||
|
||||
if cb.timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
|
||||
}
|
||||
|
||||
if cb.resetTimeout != 15*time.Second {
|
||||
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
|
||||
}
|
||||
|
||||
if cb.state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
|
||||
}
|
||||
|
||||
if cb.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
|
||||
if cb.baseRecovery != baseRecovery {
|
||||
t.Error("Expected baseRecovery to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getSuccessCount() != 1 {
|
||||
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getFailureCount() != 1 {
|
||||
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Execute(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.Execute(testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First failure
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on first failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Second failure - should open circuit
|
||||
err = cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on second failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Third attempt - should be blocked
|
||||
callCount := 0
|
||||
blockedFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
err = cb.ExecuteWithContext(ctx, blockedFunc)
|
||||
if err == nil {
|
||||
t.Error("Expected error when circuit is open")
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond, // Very short for testing
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next request should transition to half-open
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err = cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error in half-open state, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// First call should transition to half-open, but we'll force it by checking allowRequest
|
||||
if !cb.allowRequest() {
|
||||
t.Error("Expected allowRequest to return true after timeout")
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerHalfOpen {
|
||||
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Failure in half-open should return to open
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Reset circuit
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if cb.GetFailureCount() != 0 {
|
||||
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
|
||||
}
|
||||
|
||||
// Should allow requests again
|
||||
callCount := 0
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after reset, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially available
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should not be available when open
|
||||
if cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be unavailable when open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should be available again after timeout (half-open)
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available after timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateCheckers(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially closed
|
||||
if !cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker to be closed initially")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open initially")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should be open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when open")
|
||||
}
|
||||
if !cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker to be open")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open when open")
|
||||
}
|
||||
|
||||
// Wait for timeout and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
cb.allowRequest() // This will transition to half-open
|
||||
|
||||
// Should be half-open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when half-open")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open when half-open")
|
||||
}
|
||||
if !cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker to be half-open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Record some activity
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Check circuit breaker specific metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["current_failures"] != int64(1) {
|
||||
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
if metrics["timeout"] != "30s" {
|
||||
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
|
||||
}
|
||||
|
||||
if metrics["reset_timeout"] != "15s" {
|
||||
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
|
||||
}
|
||||
|
||||
// Check base metrics are included
|
||||
if metrics["total_requests"] != int64(1) {
|
||||
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["custom_metric"] != "custom_value" {
|
||||
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
|
||||
}
|
||||
|
||||
// Check failure time metrics
|
||||
if _, exists := metrics["last_failure_time"]; !exists {
|
||||
t.Error("Expected last_failure_time to exist")
|
||||
}
|
||||
|
||||
if _, exists := metrics["time_since_last_failure"]; !exists {
|
||||
t.Error("Expected time_since_last_failure to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Should still have circuit breaker metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
// Should not have base metrics
|
||||
if _, exists := metrics["total_requests"]; exists {
|
||||
t.Error("Expected total_requests not to exist without base recovery")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially should be zero
|
||||
if !cb.GetLastFailureTime().IsZero() {
|
||||
t.Error("Expected last failure time to be zero initially")
|
||||
}
|
||||
|
||||
// Record a failure
|
||||
before := time.Now()
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
after := time.Now()
|
||||
|
||||
lastFailure := cb.GetLastFailureTime()
|
||||
if lastFailure.IsZero() {
|
||||
t.Error("Expected last failure time to be set after failure")
|
||||
}
|
||||
|
||||
if lastFailure.Before(before) || lastFailure.After(after) {
|
||||
t.Errorf("Expected last failure time to be between %v and %v, got %v",
|
||||
before, after, lastFailure)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
// Should work fine without base recovery
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 10, // Higher threshold for concurrent test
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
errorCount := int64(0)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
// Simulate some failures
|
||||
if j%10 == 9 { // Every 10th operation fails
|
||||
return fmt.Errorf("simulated error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
|
||||
// Intermittently check state and metrics
|
||||
if j%5 == 0 {
|
||||
cb.GetState()
|
||||
cb.GetMetrics()
|
||||
cb.IsAvailable()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify we got both successes and errors
|
||||
finalSuccessCount := atomic.LoadInt64(&successCount)
|
||||
finalErrorCount := atomic.LoadInt64(&errorCount)
|
||||
|
||||
if finalSuccessCount == 0 {
|
||||
t.Error("Expected some successful operations")
|
||||
}
|
||||
|
||||
if finalErrorCount == 0 {
|
||||
t.Error("Expected some failed operations")
|
||||
}
|
||||
|
||||
totalOperations := finalSuccessCount + finalErrorCount
|
||||
expectedMax := int64(numGoroutines * numOperations)
|
||||
|
||||
if totalOperations > expectedMax {
|
||||
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
|
||||
}
|
||||
|
||||
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
|
||||
finalSuccessCount, finalErrorCount, cb.GetState())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Check that error was logged when circuit opened
|
||||
errorLogs := baseRecovery.getErrorLogs()
|
||||
if len(errorLogs) == 0 {
|
||||
t.Error("Expected error log when circuit breaker opened")
|
||||
} else {
|
||||
if !contains(errorLogs, "Circuit breaker opened after") {
|
||||
t.Errorf("Expected circuit opening log, got %v", errorLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Successful request should close circuit and log
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Check that success was logged when circuit closed
|
||||
infoLogs := baseRecovery.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log when circuit breaker closed")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker closed after successful request") {
|
||||
t.Errorf("Expected circuit closing log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset should also be logged
|
||||
cb.Reset()
|
||||
infoLogs = baseRecovery.getInfoLogs()
|
||||
if !contains(infoLogs, "Circuit breaker has been reset") {
|
||||
t.Errorf("Expected reset log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Wait for timeout and check half-open transition logging
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next allowRequest call should log transition to half-open
|
||||
cb.allowRequest()
|
||||
|
||||
infoLogs := logger.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log for half-open transition")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
|
||||
t.Errorf("Expected half-open transition log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a slice contains a string with substring
|
||||
func contains(slice []string, substr string) bool {
|
||||
for _, s := range slice {
|
||||
if len(s) >= len(substr) && s[:len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testFunc := func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1000, // High threshold to avoid opening during benchmark
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.GetState()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Add some activity
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
} else {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.GetMetrics()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,428 +0,0 @@
|
||||
// Package config provides configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
minEncryptionKeyLength = 16
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
//lint:ignore U1000 May be referenced for default exclusion patterns
|
||||
var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon.ico": {},
|
||||
"/robots.txt": {},
|
||||
"/health": {},
|
||||
"/.well-known/": {},
|
||||
"/metrics": {},
|
||||
"/ping": {},
|
||||
"/api/": {},
|
||||
"/static/": {},
|
||||
"/assets/": {},
|
||||
"/js/": {},
|
||||
"/css/": {},
|
||||
"/images/": {},
|
||||
"/fonts/": {},
|
||||
}
|
||||
|
||||
// Settings manages configuration and initialization for the OIDC middleware
|
||||
type Settings struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// Config represents the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerUrl"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
LogoutURL string `json:"logoutUrl"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHttps"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
Scopes []string `json:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedUrls"`
|
||||
EnablePKCE bool `json:"enablePkce"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
Headers []HeaderConfig `json:"headers"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// HeaderConfig represents header template configuration
|
||||
type HeaderConfig struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
type SecurityHeadersConfig struct {
|
||||
// Enable security headers (default: true)
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Security profile: "default", "strict", "development", "api", or "custom"
|
||||
Profile string `json:"profile"`
|
||||
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
|
||||
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
|
||||
// Content type options (default: "nosniff")
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
|
||||
// XSS protection (default: "1; mode=block")
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
CORSMaxAge int `json:"corsMaxAge"` // seconds
|
||||
|
||||
// Custom headers (in addition to standard security headers)
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
}
|
||||
|
||||
// NewSettings creates a new Settings instance
|
||||
func NewSettings(logger Logger) *Settings {
|
||||
return &Settings{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConfig creates a default configuration
|
||||
func CreateConfig() *Config {
|
||||
return &Config{
|
||||
LogLevel: "INFO",
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
RateLimit: 10,
|
||||
RefreshGracePeriodSeconds: 60,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Headers: []HeaderConfig{},
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultSecurityConfig creates a default security headers configuration
|
||||
func createDefaultSecurityConfig() *SecurityHeadersConfig {
|
||||
return &SecurityHeadersConfig{
|
||||
Enabled: true,
|
||||
Profile: "default",
|
||||
|
||||
// Default security headers
|
||||
StrictTransportSecurity: true,
|
||||
StrictTransportSecurityMaxAge: 31536000, // 1 year
|
||||
StrictTransportSecuritySubdomains: true,
|
||||
StrictTransportSecurityPreload: true,
|
||||
|
||||
FrameOptions: "DENY",
|
||||
ContentTypeOptions: "nosniff",
|
||||
XSSProtection: "1; mode=block",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
|
||||
// CORS disabled by default
|
||||
CORSEnabled: false,
|
||||
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
|
||||
CORSAllowCredentials: false,
|
||||
CORSMaxAge: 86400, // 24 hours
|
||||
|
||||
// Security features
|
||||
DisableServerHeader: true,
|
||||
DisablePoweredByHeader: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
|
||||
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
|
||||
if c == nil || !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the internal security config structure
|
||||
config := map[string]interface{}{
|
||||
"DevelopmentMode": false,
|
||||
}
|
||||
|
||||
// Apply profile-based defaults
|
||||
switch strings.ToLower(c.Profile) {
|
||||
case "strict":
|
||||
applyStrictProfile(config)
|
||||
case "development":
|
||||
applyDevelopmentProfile(config)
|
||||
case "api":
|
||||
applyAPIProfile(config)
|
||||
case "custom":
|
||||
// No defaults, use only what's explicitly configured
|
||||
default: // "default"
|
||||
applyDefaultProfile(config)
|
||||
}
|
||||
|
||||
// Override with explicit configuration
|
||||
if c.ContentSecurityPolicy != "" {
|
||||
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
|
||||
}
|
||||
|
||||
// HSTS configuration
|
||||
if c.StrictTransportSecurity {
|
||||
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
|
||||
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
|
||||
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
|
||||
}
|
||||
|
||||
// Frame options
|
||||
if c.FrameOptions != "" {
|
||||
config["FrameOptions"] = c.FrameOptions
|
||||
}
|
||||
|
||||
// Content type and XSS protection
|
||||
if c.ContentTypeOptions != "" {
|
||||
config["ContentTypeOptions"] = c.ContentTypeOptions
|
||||
}
|
||||
if c.XSSProtection != "" {
|
||||
config["XSSProtection"] = c.XSSProtection
|
||||
}
|
||||
|
||||
// Referrer and permissions policies
|
||||
if c.ReferrerPolicy != "" {
|
||||
config["ReferrerPolicy"] = c.ReferrerPolicy
|
||||
}
|
||||
if c.PermissionsPolicy != "" {
|
||||
config["PermissionsPolicy"] = c.PermissionsPolicy
|
||||
}
|
||||
|
||||
// Cross-origin policies
|
||||
if c.CrossOriginEmbedderPolicy != "" {
|
||||
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
|
||||
}
|
||||
if c.CrossOriginOpenerPolicy != "" {
|
||||
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
|
||||
}
|
||||
if c.CrossOriginResourcePolicy != "" {
|
||||
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
// CORS configuration
|
||||
config["CORSEnabled"] = c.CORSEnabled
|
||||
if len(c.CORSAllowedOrigins) > 0 {
|
||||
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
|
||||
}
|
||||
if len(c.CORSAllowedMethods) > 0 {
|
||||
config["CORSAllowedMethods"] = c.CORSAllowedMethods
|
||||
}
|
||||
if len(c.CORSAllowedHeaders) > 0 {
|
||||
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
|
||||
}
|
||||
config["CORSAllowCredentials"] = c.CORSAllowCredentials
|
||||
if c.CORSMaxAge > 0 {
|
||||
config["CORSMaxAge"] = c.CORSMaxAge
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
if len(c.CustomHeaders) > 0 {
|
||||
config["CustomHeaders"] = c.CustomHeaders
|
||||
}
|
||||
|
||||
// Security features
|
||||
config["DisableServerHeader"] = c.DisableServerHeader
|
||||
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// applyDefaultProfile applies default security settings
|
||||
func applyDefaultProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-origin"
|
||||
}
|
||||
|
||||
// applyStrictProfile applies strict security settings
|
||||
func applyStrictProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-site"
|
||||
}
|
||||
|
||||
// applyDevelopmentProfile applies development-friendly settings
|
||||
func applyDevelopmentProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
|
||||
config["FrameOptions"] = "SAMEORIGIN"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginOpenerPolicy"] = "unsafe-none"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
config["DevelopmentMode"] = true
|
||||
}
|
||||
|
||||
// applyAPIProfile applies API-friendly settings
|
||||
func applyAPIProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This would need to import the internal security package
|
||||
// For now, return a basic implementation
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
|
||||
// Apply basic security headers based on configuration
|
||||
if c.SecurityHeaders.FrameOptions != "" {
|
||||
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
|
||||
}
|
||||
if c.SecurityHeaders.ContentTypeOptions != "" {
|
||||
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
|
||||
}
|
||||
if c.SecurityHeaders.XSSProtection != "" {
|
||||
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
|
||||
}
|
||||
if c.SecurityHeaders.ReferrerPolicy != "" {
|
||||
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
|
||||
}
|
||||
if c.SecurityHeaders.ContentSecurityPolicy != "" {
|
||||
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// HSTS for HTTPS
|
||||
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
|
||||
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
|
||||
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
|
||||
hstsValue += "; includeSubDomains"
|
||||
}
|
||||
if c.SecurityHeaders.StrictTransportSecurityPreload {
|
||||
hstsValue += "; preload"
|
||||
}
|
||||
headers.Set("Strict-Transport-Security", hstsValue)
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
if c.SecurityHeaders.CORSEnabled {
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
|
||||
}
|
||||
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
|
||||
}
|
||||
if c.SecurityHeaders.CORSAllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if c.SecurityHeaders.CORSMaxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
for name, value := range c.SecurityHeaders.CustomHeaders {
|
||||
headers.Set(name, value)
|
||||
}
|
||||
|
||||
// Remove server headers
|
||||
if c.SecurityHeaders.DisableServerHeader {
|
||||
headers.Del("Server")
|
||||
}
|
||||
if c.SecurityHeaders.DisablePoweredByHeader {
|
||||
headers.Del("X-Powered-By")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if an origin is in the allowed list
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
return true
|
||||
}
|
||||
// Simple wildcard matching for subdomains
|
||||
if strings.Contains(allowed, "*") {
|
||||
if strings.HasPrefix(allowed, "https://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "https://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(allowed, "http://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "http://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalYAML() (interface{}, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return result, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
// Test that CSRF tokens persist through the authentication flow
|
||||
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
|
||||
// Create a session manager
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create initial request
|
||||
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test that marking session as dirty forces save
|
||||
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test Azure-specific session handling
|
||||
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate Azure callback scenario
|
||||
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test session continuity through auth flow
|
||||
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 1: Initial request
|
||||
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test large token handling doesn't affect CSRF
|
||||
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
|
||||
// We can't fully initialize TraefikOidc without network access,
|
||||
// but we can test the session management directly
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
|
||||
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
// TestRegressionLoginLoop specifically tests the fix for issue #53
|
||||
func TestRegressionLoginLoop(t *testing.T) {
|
||||
// This test verifies that the specific changes made to fix the login loop work correctly
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the exact flow that was causing the login loop
|
||||
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
|
||||
func TestCSRFValidationTiming(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,364 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
|
||||
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Explicitly set defaults to test backward compatibility
|
||||
ts.tOidc.roleClaimName = "roles"
|
||||
ts.tOidc.groupClaimName = "groups"
|
||||
|
||||
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin", "users"},
|
||||
"roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
|
||||
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names for Auth0
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with Auth0-style namespaced claims
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{"admin", "users"},
|
||||
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
|
||||
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom simple claim names
|
||||
ts.tOidc.roleClaimName = "user_roles"
|
||||
ts.tOidc.groupClaimName = "user_groups"
|
||||
|
||||
// Create token with custom claim names
|
||||
claims := map[string]interface{}{
|
||||
"user_groups": []interface{}{"engineering", "product"},
|
||||
"user_roles": []interface{}{"developer", "manager"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
|
||||
t.Errorf("Expected groups [engineering product], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
|
||||
t.Errorf("Expected roles [developer manager], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
|
||||
func TestCustomClaimNames_MissingClaims(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token WITHOUT the custom claims
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should return empty slices, not error
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
|
||||
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with malformed role claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": "this-should-be-an-array",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed role claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_roles claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
|
||||
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with malformed group claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": 12345, // Not an array
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed group claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_groups claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
|
||||
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only role claim name (group uses default)
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "groups" // default
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin"},
|
||||
"https://myapp.com/roles": []interface{}{"editor"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin"}) {
|
||||
t.Errorf("Expected groups [admin], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor"}) {
|
||||
t.Errorf("Expected roles [editor], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
|
||||
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only group claim name (role uses default)
|
||||
ts.tOidc.roleClaimName = "roles" // default
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"roles": []interface{}{"viewer"},
|
||||
"https://myapp.com/groups": []interface{}{"users"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"users"}) {
|
||||
t.Errorf("Expected groups [users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"viewer"}) {
|
||||
t.Errorf("Expected roles [viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
|
||||
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with empty arrays
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{},
|
||||
"https://myapp.com/roles": []interface{}{},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
|
||||
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": []interface{}{"role1", 12345, "role2", true},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
|
||||
t.Errorf("Expected roles [role1 role2], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
|
||||
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
|
||||
t.Errorf("Expected groups [group1 group2], got %v", groups)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,424 @@
|
||||
# Auth0 Audience Validation Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Understanding Audiences](#understanding-audiences)
|
||||
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
|
||||
3. [Configuration Options](#configuration-options)
|
||||
4. [Security Recommendations](#security-recommendations)
|
||||
5. [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Understanding Audiences
|
||||
|
||||
### What is an Audience?
|
||||
|
||||
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
|
||||
|
||||
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
|
||||
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
|
||||
|
||||
### Why Does This Matter?
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
|
||||
---
|
||||
|
||||
## The Three Auth0 Scenarios
|
||||
|
||||
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request includes `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
|
||||
3. Middleware validates:
|
||||
- ID tokens against `client_id`
|
||||
- Access tokens against custom audience
|
||||
|
||||
**Result:** ✅ Fully secure, OIDC compliant
|
||||
|
||||
---
|
||||
|
||||
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# audience not specified (defaults to client_id)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request WITHOUT `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
|
||||
3. Access token validation fails (audience mismatch)
|
||||
4. Middleware falls back to ID token validation
|
||||
|
||||
**Security Warning:**
|
||||
```
|
||||
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
|
||||
⚠️ This could allow tokens intended for different APIs to grant access
|
||||
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
|
||||
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
|
||||
```
|
||||
|
||||
**Recommended Fix:**
|
||||
```yaml
|
||||
strictAudienceValidation: true # Reject sessions with audience mismatch
|
||||
```
|
||||
|
||||
**Result:**
|
||||
- Default: ⚠️ Works but logs security warnings
|
||||
- With strict mode: ✅ Secure (rejects mismatched tokens)
|
||||
|
||||
---
|
||||
|
||||
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true # Enable opaque token support
|
||||
requireTokenIntrospection: true # Require introspection (recommended)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Auth0 issues opaque (non-JWT) access token
|
||||
2. Middleware detects opaque token (not 3 parts separated by dots)
|
||||
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
|
||||
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
|
||||
|
||||
**Requirements:**
|
||||
- Provider must support `introspection_endpoint` in OIDC discovery
|
||||
- Client must have introspection permissions
|
||||
|
||||
**Result:** ✅ Secure with introspection, ⚠️ risky without
|
||||
|
||||
---
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Audience Settings
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `audience` | string | `client_id` | Expected audience for access tokens |
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
# .traefik.yml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
audience: "https://my-api.example.com"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Security Mode Settings
|
||||
|
||||
#### `strictAudienceValidation`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` for production
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
|
||||
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use in production environments
|
||||
- ✅ When you have custom API audiences configured in Auth0
|
||||
- ⚠️ May break existing deployments relying on Scenario 2 behavior
|
||||
|
||||
---
|
||||
|
||||
#### `allowOpaqueTokens`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Accepts opaque (non-JWT) access tokens
|
||||
- When `false`: Only accepts JWT access tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ When Auth0 issues opaque tokens (no default API configured)
|
||||
- ✅ When using Auth0 Management API tokens
|
||||
- ⚠️ Requires introspection endpoint for security
|
||||
|
||||
---
|
||||
|
||||
#### `requireTokenIntrospection`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` when `allowOpaqueTokens=true`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
|
||||
- When `false`: Falls back to ID token validation for opaque tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
|
||||
- ⚠️ Requires provider to expose introspection endpoint
|
||||
|
||||
---
|
||||
|
||||
## Security Recommendations
|
||||
|
||||
### Recommended Configuration for Auth0
|
||||
|
||||
**For APIs with custom audiences (Scenario 1):**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
allowOpaqueTokens: false
|
||||
```
|
||||
|
||||
**For default Auth0 setup (Scenario 2):**
|
||||
```yaml
|
||||
# Don't set audience (defaults to client_id)
|
||||
strictAudienceValidation: true # Enforce proper configuration
|
||||
```
|
||||
|
||||
**For opaque tokens (Scenario 3):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. ✅ **Always set `strictAudienceValidation: true` in production**
|
||||
2. ✅ **Configure custom API audiences in Auth0 dashboard**
|
||||
3. ✅ **Use `requireTokenIntrospection: true` if accepting opaque tokens**
|
||||
4. ✅ **Monitor logs for security warnings**
|
||||
5. ❌ **Don't rely on Scenario 2 fallback behavior**
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Access token validation failed due to audience mismatch"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
|
||||
```
|
||||
|
||||
**Cause:** Access token audience doesn't match configured audience
|
||||
|
||||
**Solutions:**
|
||||
1. **Configure correct audience:**
|
||||
```yaml
|
||||
audience: "https://your-api-identifier" # From Auth0 API settings
|
||||
```
|
||||
|
||||
2. **Update Auth0 authorization request:**
|
||||
- Ensure `audience` parameter is included in authorize URL
|
||||
- Middleware automatically adds this when `audience != client_id`
|
||||
|
||||
3. **Accept the behavior (not recommended):**
|
||||
```yaml
|
||||
strictAudienceValidation: false # Logs warnings but allows
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### "Opaque token detected but allowOpaqueTokens=false"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque access token detected but allowOpaqueTokens=false
|
||||
```
|
||||
|
||||
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
|
||||
|
||||
**Solutions:**
|
||||
1. **Enable opaque tokens:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT access tokens:**
|
||||
- Create an API in Auth0 dashboard
|
||||
- Set API identifier as `audience` in configuration
|
||||
|
||||
---
|
||||
|
||||
### "Introspection endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
|
||||
```
|
||||
|
||||
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
|
||||
|
||||
**Solutions:**
|
||||
1. **Check provider discovery:**
|
||||
```bash
|
||||
curl https://YOUR_DOMAIN/.well-known/openid-configuration
|
||||
```
|
||||
Look for `introspection_endpoint`
|
||||
|
||||
2. **Disable required introspection (less secure):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: false # Falls back to ID token
|
||||
```
|
||||
|
||||
3. **Use JWT access tokens instead** (recommended)
|
||||
|
||||
---
|
||||
|
||||
### "Token introspection required but endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
❌ SECURITY: Opaque token rejected (introspection required but failed)
|
||||
```
|
||||
|
||||
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
|
||||
|
||||
**Solutions:**
|
||||
1. **Disable required introspection:**
|
||||
```yaml
|
||||
requireTokenIntrospection: false
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT tokens** (better solution)
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Token Type Detection
|
||||
|
||||
The middleware uses a sophisticated 6-step detection algorithm:
|
||||
|
||||
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
|
||||
2. **Explicit type claims**: `token_use`, `token_type`
|
||||
3. **`scope` claim**: Present → Access Token
|
||||
4. **`nonce` claim**: Present → ID Token (OIDC spec)
|
||||
5. **Audience check**: `aud == client_id` only → ID Token
|
||||
6. **Default**: Access Token
|
||||
|
||||
### OAuth 2.0 Token Introspection (RFC 7662)
|
||||
|
||||
When opaque tokens are detected:
|
||||
|
||||
1. Middleware calls provider's `introspection_endpoint`
|
||||
2. Authenticates using client credentials
|
||||
3. Receives response with `active` status and claims
|
||||
4. Caches result for 5 minutes (configurable via TTL)
|
||||
5. Validates expiration, not-before, and audience if present
|
||||
|
||||
**Cache behavior:**
|
||||
- Cache key: Token hash
|
||||
- TTL: 5 minutes or token expiry (whichever is shorter)
|
||||
- Reduces introspection requests for frequently used tokens
|
||||
|
||||
---
|
||||
|
||||
## Reference Links
|
||||
|
||||
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
|
||||
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
|
||||
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
|
||||
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
|
||||
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
|
||||
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Previous Versions
|
||||
|
||||
**If you're upgrading from a version without these features:**
|
||||
|
||||
1. **No action required for default behavior** - backward compatible
|
||||
2. **Recommended: Enable strict mode gradually**
|
||||
```yaml
|
||||
# Step 1: Enable and monitor logs
|
||||
strictAudienceValidation: false # Default
|
||||
|
||||
# Step 2: After confirming no warnings, enable
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
3. **For opaque tokens: Enable explicitly**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
### Testing Your Configuration
|
||||
|
||||
1. **Check logs for warnings:**
|
||||
```bash
|
||||
# Look for Scenario 2 warnings
|
||||
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
|
||||
|
||||
# Look for opaque token warnings
|
||||
grep "Opaque" /var/log/traefik.log
|
||||
```
|
||||
|
||||
2. **Test with curl:**
|
||||
```bash
|
||||
# Get token from Auth0
|
||||
ACCESS_TOKEN="your_access_token"
|
||||
|
||||
# Test request
|
||||
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
|
||||
https://your-app.example.com/api
|
||||
```
|
||||
|
||||
3. **Monitor for security warnings in production logs**
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
|
||||
- Security issues: See SECURITY.md for responsible disclosure
|
||||
|
||||
---
|
||||
|
||||
**Last Updated:** 2025-01-09
|
||||
**Version:** 0.7.8+
|
||||
@@ -0,0 +1 @@
|
||||
traefikoidc.raczylo.com
|
||||
@@ -0,0 +1,456 @@
|
||||
# Configuration Reference
|
||||
|
||||
Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
- [Access Control](#access-control)
|
||||
- [Headers Configuration](#headers-configuration)
|
||||
- [Security Headers](#security-headers)
|
||||
- [Scope Configuration](#scope-configuration)
|
||||
- [Advanced Options](#advanced-options)
|
||||
|
||||
---
|
||||
|
||||
## Required Parameters
|
||||
|
||||
| Parameter | Type | Description | Example |
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
### Basic Configuration Example
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
|
||||
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
|
||||
|
||||
```yaml
|
||||
forceHTTPS: true # Required for correct redirect URIs
|
||||
```
|
||||
|
||||
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
|
||||
### Audience Validation
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `audience` | string | `clientID` | Expected audience for access token validation |
|
||||
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
|
||||
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
|
||||
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
|
||||
|
||||
#### Production Security Configuration
|
||||
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
#### Opaque Token Support
|
||||
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Other Security Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
|
||||
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
### Multi-Subdomain Setup
|
||||
|
||||
```yaml
|
||||
cookieDomain: .example.com # Share cookies across subdomains
|
||||
```
|
||||
|
||||
### Multiple Middleware Instances
|
||||
|
||||
When running multiple middleware instances with different authorization requirements, use unique prefixes:
|
||||
|
||||
```yaml
|
||||
# User authentication middleware
|
||||
---
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-userauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_userauth_"
|
||||
sessionEncryptionKey: user-encryption-key-min-32-bytes
|
||||
# ... other config
|
||||
---
|
||||
# Admin authentication middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-adminauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_adminauth_"
|
||||
sessionEncryptionKey: admin-encryption-key-min-32-bytes
|
||||
allowedUsers:
|
||||
- admin@example.com
|
||||
# ... other config
|
||||
```
|
||||
|
||||
### Extended Session Duration
|
||||
|
||||
```yaml
|
||||
sessionMaxAge: 604800 # 7 days
|
||||
# Common values:
|
||||
# 3600 - 1 hour (high security)
|
||||
# 86400 - 1 day (default)
|
||||
# 259200 - 3 days
|
||||
# 604800 - 7 days
|
||||
# 2592000 - 30 days
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Access Control
|
||||
|
||||
### User Restrictions
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `allowedUserDomains` | []string | Restrict to specific email domains |
|
||||
| `allowedUsers` | []string | Specific email addresses allowed |
|
||||
| `allowedRolesAndGroups` | []string | Required roles or groups |
|
||||
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
|
||||
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
|
||||
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
|
||||
|
||||
### Domain Restriction
|
||||
|
||||
```yaml
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### Specific User Access
|
||||
|
||||
```yaml
|
||||
allowedUsers:
|
||||
- user@example.com
|
||||
- contractor@external.org
|
||||
```
|
||||
|
||||
### Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
|
||||
```
|
||||
|
||||
### Access Control Logic
|
||||
|
||||
- If only `allowedUsers` is set: Only specified emails can access
|
||||
- If only `allowedUserDomains` is set: Only specified domains can access
|
||||
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
|
||||
- If neither is set: Any authenticated user can access
|
||||
|
||||
### Users Without Email (Azure AD)
|
||||
|
||||
For Azure AD service accounts or users without email:
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Headers Configuration
|
||||
|
||||
### Default Headers
|
||||
|
||||
The middleware sets these headers for downstream services:
|
||||
|
||||
| Header | Description |
|
||||
|--------|-------------|
|
||||
| `X-Forwarded-User` | User's email address |
|
||||
| `X-User-Groups` | Comma-separated user groups |
|
||||
| `X-User-Roles` | Comma-separated user roles |
|
||||
| `X-Auth-Request-Redirect` | Original request URI |
|
||||
| `X-Auth-Request-User` | User's email address |
|
||||
| `X-Auth-Request-Token` | User's ID token |
|
||||
|
||||
### Minimal Headers Mode
|
||||
|
||||
For "431 Request Header Fields Too Large" errors:
|
||||
|
||||
```yaml
|
||||
minimalHeaders: true # Only forwards X-Forwarded-User
|
||||
```
|
||||
|
||||
### Custom Templated Headers
|
||||
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{{{.Claims.email}}}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{{{.Claims.sub}}}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
```
|
||||
|
||||
**Template Variables:**
|
||||
- `{{.Claims.field}}` - ID token claims
|
||||
- `{{.AccessToken}}` - Raw access token
|
||||
- `{{.IdToken}}` - Raw ID token
|
||||
- `{{.RefreshToken}}` - Raw refresh token
|
||||
|
||||
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers
|
||||
|
||||
### Security Profiles
|
||||
|
||||
| Profile | Use Case | Security Level |
|
||||
|---------|----------|----------------|
|
||||
| `default` | Standard web apps | High |
|
||||
| `strict` | Maximum security | Very High |
|
||||
| `development` | Local development | Medium |
|
||||
| `api` | API endpoints | High |
|
||||
| `custom` | Custom requirements | Configurable |
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
### API with CORS
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.example.com"
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
|
||||
|
||||
# HSTS
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and Content Protection
|
||||
frameOptions: "DENY"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# CORS
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400
|
||||
|
||||
# Custom Headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "value"
|
||||
|
||||
# Server Identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### CORS Origin Patterns
|
||||
|
||||
```yaml
|
||||
corsAllowedOrigins:
|
||||
- "https://example.com" # Exact match
|
||||
- "https://*.example.com" # Subdomain wildcard
|
||||
- "http://localhost:*" # Port wildcard (development)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
### Default Behavior (Append Mode)
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- roles
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
|
||||
```
|
||||
|
||||
### Override Mode
|
||||
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "custom_scope"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Advanced Options
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true
|
||||
```
|
||||
|
||||
With Redis (recommended):
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
See [REDIS.md](REDIS.md) for complete Redis configuration.
|
||||
|
||||
---
|
||||
|
||||
## Kubernetes Secrets
|
||||
|
||||
Reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
providerURL: urn:k8s:secret:oidc-secret:ISSUER
|
||||
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:oidc-secret:SECRET
|
||||
```
|
||||
|
||||
Create the secret:
|
||||
|
||||
```bash
|
||||
kubectl create secret generic oidc-secret \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=your-client-id \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Environment Variable Naming
|
||||
|
||||
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
|
||||
|
||||
```yaml
|
||||
# Bad - may cause issues
|
||||
sessionEncryptionKey: ${OIDC_SECRET_API}
|
||||
|
||||
# Good
|
||||
sessionEncryptionKey: ${OIDC_SECRET_SVC}
|
||||
```
|
||||
@@ -0,0 +1,455 @@
|
||||
# Development Guide
|
||||
|
||||
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Local Development Setup](#local-development-setup)
|
||||
- [Running Tests](#running-tests)
|
||||
- [Test Categories](#test-categories)
|
||||
- [CI/CD Pipeline](#cicd-pipeline)
|
||||
- [Code Quality](#code-quality)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Go 1.23+** for plugin compilation
|
||||
- **Docker & Docker Compose** for local testing
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.)
|
||||
|
||||
### Required Development Tools
|
||||
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Development Setup
|
||||
|
||||
### Docker Compose Environment
|
||||
|
||||
The repository includes a Docker Compose setup for testing the plugin locally.
|
||||
|
||||
#### 1. Host Configuration
|
||||
|
||||
Add to `/etc/hosts`:
|
||||
|
||||
```bash
|
||||
127.0.0.1 hello.localhost
|
||||
127.0.0.1 traefik.localhost
|
||||
```
|
||||
|
||||
#### 2. Plugin Configuration
|
||||
|
||||
The plugin is loaded using Traefik's **local plugins mode**:
|
||||
|
||||
- Plugin source: Parent directory (`../`)
|
||||
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
|
||||
- Configuration: `experimental.localPlugins` in `traefik.yml`
|
||||
|
||||
#### 3. OIDC Provider Setup
|
||||
|
||||
Edit `docker/dynamic.yml` with your provider details:
|
||||
|
||||
**Google:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
logoutURL: "/oauth2/logout"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
**Azure AD:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
#### 4. Start Environment
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 5. Test Plugin
|
||||
|
||||
- **Protected App**: http://hello.localhost (redirects to OIDC)
|
||||
- **Traefik Dashboard**: http://traefik.localhost:8080
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Edit plugin code** in the project root
|
||||
2. **Build and test** (optional syntax check):
|
||||
```bash
|
||||
go mod tidy
|
||||
go build .
|
||||
go test ./...
|
||||
```
|
||||
3. **Restart Traefik** to reload plugin:
|
||||
```bash
|
||||
docker-compose restart traefik
|
||||
```
|
||||
4. **Test changes** at http://hello.localhost
|
||||
|
||||
### Debugging
|
||||
|
||||
**View plugin logs:**
|
||||
```bash
|
||||
docker-compose logs -f traefik | grep traefikoidc
|
||||
```
|
||||
|
||||
**Check plugin loading:**
|
||||
```bash
|
||||
docker-compose logs traefik | grep -i plugin
|
||||
```
|
||||
|
||||
**Verify plugin directory:**
|
||||
```bash
|
||||
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Fast development testing (< 30 seconds)
|
||||
go test ./... -short
|
||||
|
||||
# Standard tests with race detector
|
||||
go test -race -timeout=15m ./...
|
||||
|
||||
# With coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
```
|
||||
|
||||
### Test Modes
|
||||
|
||||
| Mode | Command | Duration | Use Case |
|
||||
|------|---------|----------|----------|
|
||||
| Quick | `go test ./... -short` | < 30s | During development |
|
||||
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
|
||||
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
|
||||
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1
|
||||
export RUN_LONG_TESTS=1
|
||||
export RUN_STRESS_TESTS=1
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
|
||||
# Customize test parameters
|
||||
export TEST_MAX_CONCURRENCY=10
|
||||
export TEST_MAX_ITERATIONS=50
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Quick Tests (Default)
|
||||
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### Extended Tests
|
||||
|
||||
- Comprehensive testing before commits
|
||||
- More iterations (5-10)
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### Long Tests
|
||||
|
||||
- Performance validation
|
||||
- High iteration counts (50-100)
|
||||
- Large data sets
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### Stress Tests
|
||||
|
||||
- Maximum load testing
|
||||
- Edge case validation
|
||||
- Extreme parameters
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Timeout: 120 seconds
|
||||
|
||||
### Running Specific Test Suites
|
||||
|
||||
```bash
|
||||
# Memory leak tests
|
||||
go test -v -run='.*Leak.*' ./...
|
||||
|
||||
# Integration tests
|
||||
go test -v -run='.*Integration.*' ./...
|
||||
|
||||
# Regression tests
|
||||
go test -v -run='.*Regression.*' ./...
|
||||
|
||||
# Provider-specific tests
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
go test -v -run='.*Google.*' ./...
|
||||
```
|
||||
|
||||
### Benchmarks
|
||||
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
|
||||
|
||||
### Triggered On
|
||||
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
### Parallel Jobs
|
||||
|
||||
#### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
#### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
#### Testing (9 suites)
|
||||
- Race Detector
|
||||
- Coverage (75% threshold)
|
||||
- Memory Leaks
|
||||
- Integration Tests
|
||||
- Regression Tests
|
||||
- Security Edge Cases
|
||||
- Session Tests
|
||||
- Token Tests
|
||||
- CSRF Tests
|
||||
|
||||
#### Provider Testing (9 providers)
|
||||
Tests run in parallel for:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Build (3 checks)
|
||||
- Benchmarks
|
||||
- Multi-platform Build (linux/darwin x amd64/arm64)
|
||||
- Go Version Compatibility (Go 1.23 & 1.24)
|
||||
|
||||
### Quality Gates
|
||||
|
||||
All PRs must pass:
|
||||
- All parallel checks
|
||||
- 75% test coverage minimum
|
||||
- Zero security vulnerabilities
|
||||
- No race conditions
|
||||
- No memory leaks
|
||||
- All providers tested
|
||||
- Builds on all platforms
|
||||
|
||||
---
|
||||
|
||||
## Code Quality
|
||||
|
||||
### Pre-Commit Checklist
|
||||
|
||||
```bash
|
||||
# Run before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "Ready to commit!"
|
||||
```
|
||||
|
||||
### Local Validation
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Static analysis
|
||||
staticcheck ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
|
||||
# Vulnerability check
|
||||
govulncheck ./...
|
||||
|
||||
# Tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# View coverage in browser
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Coverage Below Threshold:**
|
||||
```bash
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out # See uncovered lines
|
||||
```
|
||||
|
||||
**Race Condition Found:**
|
||||
```bash
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
**Linter Errors:**
|
||||
```bash
|
||||
golangci-lint run -v
|
||||
golangci-lint run --fix # Auto-fix some issues
|
||||
```
|
||||
|
||||
**Provider Test Fails:**
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
|
||||
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
|
||||
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
|
||||
4. **Documentation**: Update README and configuration files for new options
|
||||
|
||||
### Pull Request Template
|
||||
|
||||
PRs should include:
|
||||
- Description of changes
|
||||
- Type of change (bug fix, feature, breaking change, etc.)
|
||||
- Related issues
|
||||
- Provider impact (which providers are affected)
|
||||
- Testing performed
|
||||
- Security considerations
|
||||
- Performance impact
|
||||
- Breaking changes (if any)
|
||||
|
||||
### Checklist
|
||||
|
||||
Before submitting:
|
||||
- [ ] Code follows project style
|
||||
- [ ] Self-review completed
|
||||
- [ ] Tests added for new functionality
|
||||
- [ ] All tests pass locally
|
||||
- [ ] Documentation updated
|
||||
- [ ] No new warnings generated
|
||||
|
||||
### Code Owners
|
||||
|
||||
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
|
||||
|
||||
### Dependabot
|
||||
|
||||
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
@@ -0,0 +1,582 @@
|
||||
# OIDC Provider Configuration Guide
|
||||
|
||||
Configuration reference for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Provider Support Matrix](#provider-support-matrix)
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [Okta](#okta)
|
||||
- [Keycloak](#keycloak)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [GitLab](#gitlab)
|
||||
- [GitHub](#github)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
- [Automatic Scope Filtering](#automatic-scope-filtering)
|
||||
|
||||
---
|
||||
|
||||
## Provider Support Matrix
|
||||
|
||||
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|
||||
|----------|-------------|----------------|----------------|-----------|
|
||||
| Google | Full | Yes | `accounts.google.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
|
||||
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com` | Yes |
|
||||
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
|
||||
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
|
||||
| GitLab | Full | Yes | `gitlab.com` | Yes |
|
||||
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
|
||||
| Generic | Full | Yes | Any OIDC endpoint | Yes |
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- "your-gsuite-domain.com" # Optional: Workspace restriction
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
|
||||
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
|
||||
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
|
||||
|
||||
### Google Cloud Console Setup
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Navigate to APIs & Services > Credentials
|
||||
4. Create OAuth 2.0 Client ID (Web application)
|
||||
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
6. Configure OAuth consent screen (must be "Published" for production)
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# Single tenant
|
||||
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# Multi-tenant
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- "App.Users"
|
||||
- "Admin.Group"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### With Application ID URI (API Access)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
audience: "api://your-azure-client-id" # Application ID URI
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Users Without Email
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "user-object-id-1"
|
||||
- "user-object-id-2"
|
||||
```
|
||||
|
||||
### Azure AD Setup
|
||||
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to Azure Active Directory > App registrations
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Create client secret in Certificates & secrets
|
||||
6. Configure Token Configuration for group claims
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
postLogoutRedirectUri: "https://your-app.com"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### With Custom API Audience
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
audience: "https://api.your-domain.com" # API identifier
|
||||
strictAudienceValidation: true
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
roleClaimName: "https://your-app.com/roles" # Namespaced claim
|
||||
groupClaimName: "https://your-app.com/groups"
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
```
|
||||
|
||||
### Auth0 Action for Custom Claims
|
||||
|
||||
```javascript
|
||||
exports.onExecutePostLogin = async (event, api) => {
|
||||
const namespace = 'https://your-app.com/';
|
||||
if (event.authorization) {
|
||||
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
|
||||
api.idToken.setCustomClaim('email', event.user.email);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Auth0 Setup
|
||||
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create Regular Web Application
|
||||
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
6. Create API in APIs section for custom audiences
|
||||
|
||||
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
# Or with custom authorization server:
|
||||
providerURL: "https://your-domain.okta.com/oauth2/default"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-okta
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
clientID: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- "Everyone"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Setup
|
||||
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect > Web Application
|
||||
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
|
||||
6. Enable Authorization Code and Refresh Token grant types
|
||||
7. Configure Groups claim in authorization server
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-keycloak
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://keycloak.company.com/realms/your-realm"
|
||||
clientID: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- roles
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
For private IP addresses (Docker networks, Kubernetes):
|
||||
|
||||
```yaml
|
||||
providerURL: "https://192.168.1.100:8443/realms/your-realm"
|
||||
allowPrivateIPAddresses: true # Required for private IPs
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select your realm
|
||||
3. Go to Clients > Create client
|
||||
4. Set Client Protocol: openid-connect
|
||||
5. Set Access Type: confidential
|
||||
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
7. Generate client secret in Credentials tab
|
||||
8. Configure mappers to add claims to ID Token:
|
||||
- Email: User Property mapper with "Add to ID token" enabled
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-cognito
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientID: "your-cognito-client-id"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- aws.cognito.signin.user.admin
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- users
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
- Sign out URLs: `https://your-domain.com/oauth2/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
5. Set up groups for role-based access
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerURL: "https://gitlab.com"
|
||||
|
||||
# Self-hosted
|
||||
providerURL: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-gitlab
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://gitlab.com"
|
||||
clientID: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
# Note: GitLab doesn't require offline_access scope
|
||||
# Refresh tokens are issued automatically with openid
|
||||
allowedRolesAndGroups:
|
||||
- developers
|
||||
- maintainers
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Setup
|
||||
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Save and note Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://github.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oauth-github
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://github.com/login/oauth"
|
||||
clientID: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- user:email
|
||||
- read:user
|
||||
allowedUsers:
|
||||
- "github-username"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- **OAuth 2.0 only** - Not OpenID Connect
|
||||
- **No ID tokens** - Only access tokens for API calls
|
||||
- **No refresh tokens** - Users must re-authenticate on expiry
|
||||
- **No standard claims** - User info requires API calls
|
||||
|
||||
Use GitHub only for API access, not for user authentication with claims.
|
||||
|
||||
### GitHub Setup
|
||||
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
|
||||
4. Note Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
For any OIDC-compliant provider not listed above.
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-generic
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://oidc.your-provider.com"
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Requirements
|
||||
|
||||
- Provider must expose `.well-known/openid-configuration` endpoint
|
||||
- Must support authorization code flow
|
||||
- ID tokens must contain required claims (email, sub, etc.)
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Fetches provider's `.well-known/openid-configuration`
|
||||
2. Extracts `scopes_supported` field
|
||||
3. Filters requested scopes to only include supported ones
|
||||
4. Falls back to all requested scopes if provider doesn't declare supported scopes
|
||||
|
||||
### Example: Self-Hosted GitLab
|
||||
|
||||
Self-hosted GitLab may reject `offline_access` scope:
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access # Will be automatically filtered out if unsupported
|
||||
```
|
||||
|
||||
The middleware will:
|
||||
1. Read GitLab's discovery document
|
||||
2. Detect `offline_access` is NOT in `scopes_supported`
|
||||
3. Filter it out automatically
|
||||
4. Authentication succeeds
|
||||
|
||||
### Logging
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If a provider rejects scopes even after filtering:
|
||||
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
|
||||
2. Use `overrideScopes: true` with only supported scopes
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
@@ -1,770 +0,0 @@
|
||||
# Provider-Specific Configuration Guide
|
||||
|
||||
This guide covers the configuration requirements and best practices for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [GitHub](#github)
|
||||
- [GitLab](#gitlab)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [Keycloak](#keycloak)
|
||||
- [Okta](#okta)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-google-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
|
||||
- **Refresh token support**: Fully supported
|
||||
- **Domain restrictions**: Can restrict by Google Workspace domains
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedUserDomains: ["example.com", "company.org"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google OAuth Console Setup
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# For Azure AD (single tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# For Azure AD (multi-tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-azure-application-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Azure-Specific Features
|
||||
- **Response mode**: Automatically adds `response_mode=query`
|
||||
- **Offline access**: Requires `offline_access` scope for refresh tokens
|
||||
- **Access token validation**: Supports both JWT and opaque access tokens
|
||||
- **Tenant isolation**: Can restrict to specific Azure AD tenants
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Azure App Registration Setup
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to "Azure Active Directory" > "App registrations"
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Create client secret in "Certificates & secrets"
|
||||
6. Configure API permissions for required scopes
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Auth0-Specific Features
|
||||
- **Custom domains**: Supports Auth0 custom domains
|
||||
- **Rules and hooks**: Leverages Auth0's extensibility
|
||||
- **Social connections**: Works with Auth0's social identity providers
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedUsers: ["user@example.com", "admin@company.com"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Auth0 Application Setup
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create new application (Regular Web Application)
|
||||
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
|
||||
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://github.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["read:user", "user:email"]
|
||||
```
|
||||
|
||||
### GitHub-Specific Features
|
||||
- **Organization membership**: Can restrict by GitHub organization
|
||||
- **Team membership**: Can restrict by specific teams
|
||||
- **Limited OIDC**: GitHub has limited OIDC support
|
||||
- **Email verification**: Requires verified email addresses
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
github-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://github.com"
|
||||
clientId: "Iv1.abcdef123456"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["read:user", "user:email"]
|
||||
allowedUsers: ["octocat", "github-user"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### GitHub OAuth App Setup
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
|
||||
4. Note the Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerUrl: "https://gitlab.com"
|
||||
|
||||
# Self-hosted GitLab
|
||||
providerUrl: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### GitLab-Specific Features
|
||||
- **Self-hosted support**: Works with self-hosted GitLab instances
|
||||
- **Group membership**: Can restrict by GitLab groups
|
||||
- **Project access**: Can validate project permissions
|
||||
- **Offline access**: Supports refresh tokens with `offline_access`
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["developers", "maintainers"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Application Setup
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Save and note the Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-cognito-app-client-id"
|
||||
clientSecret: "your-cognito-app-client-secret" # If app client has secret
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Cognito-Specific Features
|
||||
- **User pools**: Integrates with Cognito User Pools
|
||||
- **Custom attributes**: Supports custom user attributes
|
||||
- **Groups**: Can validate Cognito user group membership
|
||||
- **Regional endpoints**: Requires region-specific URLs
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
cognito-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientId: "1234567890abcdefghij"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedRolesAndGroups: ["admin", "users"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/auth/callback`
|
||||
- Sign out URLs: `https://your-domain.com/auth/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Keycloak-Specific Features
|
||||
- **Realm support**: Multi-realm deployments
|
||||
- **Custom mappers**: Rich claim mapping capabilities
|
||||
- **Role-based access**: Fine-grained role management
|
||||
- **Offline access**: Full refresh token support
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
keycloak-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://keycloak.company.com/realms/employees"
|
||||
clientId: "traefik-app"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["app-users", "administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select appropriate realm
|
||||
3. Create new client:
|
||||
- Client Protocol: openid-connect
|
||||
- Access Type: confidential
|
||||
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.okta.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Okta-Specific Features
|
||||
- **Custom authorization servers**: Supports custom auth servers
|
||||
- **Group claims**: Rich group membership information
|
||||
- **Universal Directory**: Integrates with Okta's user store
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
okta-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.okta.com"
|
||||
clientId: "0oa123456789abcdef"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["Everyone", "Administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Application Setup
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect
|
||||
4. Choose Web Application
|
||||
5. Configure:
|
||||
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
|
||||
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
|
||||
- Grant types: Authorization Code, Refresh Token
|
||||
6. Assign users or groups
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-oidc-provider.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Generic Features
|
||||
- **Standards compliance**: Works with any OIDC-compliant provider
|
||||
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
|
||||
- **Flexible scopes**: Supports custom scope requirements
|
||||
- **Custom claims**: Works with provider-specific claims
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
generic-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://oidc.your-provider.com"
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Configuration Options
|
||||
|
||||
### Security Settings
|
||||
```yaml
|
||||
# Force HTTPS (recommended for production)
|
||||
forceHttps: true
|
||||
|
||||
# Enable PKCE (recommended for security)
|
||||
enablePkce: true
|
||||
|
||||
# Session encryption key (32+ characters)
|
||||
sessionEncryptionKey: "your-very-long-encryption-key-here"
|
||||
```
|
||||
|
||||
### Access Control
|
||||
```yaml
|
||||
# Restrict by email addresses
|
||||
allowedUsers: ["user1@example.com", "user2@example.com"]
|
||||
|
||||
# Restrict by email domains
|
||||
allowedUserDomains: ["company.com", "partner.org"]
|
||||
|
||||
# Restrict by roles/groups (provider-specific)
|
||||
allowedRolesAndGroups: ["admin", "users", "developers"]
|
||||
```
|
||||
|
||||
### URLs and Endpoints
|
||||
```yaml
|
||||
# OAuth callback URL (must match provider config)
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
|
||||
# Logout endpoint
|
||||
logoutUrl: "https://your-domain.com/auth/logout"
|
||||
|
||||
# Post-logout redirect (optional)
|
||||
postLogoutRedirectUri: "https://your-domain.com"
|
||||
|
||||
# URLs to exclude from authentication
|
||||
excludedUrls: ["/health", "/metrics", "/public"]
|
||||
```
|
||||
|
||||
### Advanced Settings
|
||||
```yaml
|
||||
# Override default scopes
|
||||
overrideScopes: true
|
||||
scopes: ["openid", "custom_scope"]
|
||||
|
||||
# Rate limiting (requests per second)
|
||||
rateLimit: 10
|
||||
|
||||
# Token refresh grace period (seconds)
|
||||
refreshGracePeriodSeconds: 60
|
||||
|
||||
# Cookie domain (for subdomain sharing)
|
||||
cookieDomain: ".example.com"
|
||||
|
||||
# Custom headers to inject
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-Name"
|
||||
value: "{{.Claims.name}}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Invalid redirect URI**
|
||||
- Ensure callback URL exactly matches provider configuration
|
||||
- Check for HTTP vs HTTPS mismatches
|
||||
|
||||
2. **Scope errors**
|
||||
- Verify required scopes are configured in provider
|
||||
- Some providers require specific scopes for refresh tokens
|
||||
|
||||
3. **Token validation failures**
|
||||
- Check provider URL format and accessibility
|
||||
- Verify `.well-known/openid-configuration` endpoint is reachable
|
||||
|
||||
4. **Session issues**
|
||||
- Ensure session encryption key is properly configured
|
||||
- Check cookie domain settings for subdomain scenarios
|
||||
|
||||
### Debug Mode
|
||||
Enable debug logging to troubleshoot configuration issues:
|
||||
```yaml
|
||||
logLevel: "debug"
|
||||
```
|
||||
|
||||
This will provide detailed logs of the authentication flow and help identify configuration problems.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
|
||||
|
||||
### Default Security Headers
|
||||
|
||||
By default, the plugin applies these security headers:
|
||||
|
||||
- `X-Frame-Options: DENY` - Prevents clickjacking
|
||||
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
|
||||
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
|
||||
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
|
||||
|
||||
### Security Profiles
|
||||
|
||||
Choose from predefined security profiles or create custom configurations:
|
||||
|
||||
#### Default Profile (Recommended)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
#### Strict Profile (Maximum Security)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
# Additional strict CSP and cross-origin policies
|
||||
```
|
||||
|
||||
#### Development Profile (Local Development)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "development"
|
||||
# Relaxed policies for local development
|
||||
```
|
||||
|
||||
#### API Profile (API Endpoints)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://your-frontend.com"]
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
For complete control, use the custom profile:
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
|
||||
|
||||
# HSTS Configuration
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000 # 1 year
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and content protection
|
||||
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions policy (feature policy)
|
||||
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Cross-origin policies
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# CORS configuration
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://app.example.com"
|
||||
- "https://*.api.example.com"
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400 # 24 hours
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "custom-value"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Server identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### Complete Example with Security Headers
|
||||
|
||||
Here's a complete configuration example for Google OIDC with custom security headers:
|
||||
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
secure-google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
# OIDC Configuration
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key-here"
|
||||
|
||||
# Domain restrictions
|
||||
allowedUserDomains: ["your-company.com"]
|
||||
|
||||
# Security Headers
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.your-domain.com"
|
||||
corsAllowCredentials: true
|
||||
customHeaders:
|
||||
X-Company: "YourCompany"
|
||||
X-Environment: "production"
|
||||
|
||||
routers:
|
||||
secure-app:
|
||||
rule: "Host(`your-domain.com`)"
|
||||
middlewares:
|
||||
- secure-google-oidc
|
||||
service: your-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
```
|
||||
|
||||
### CORS Configuration Details
|
||||
|
||||
For applications with frontend-backend separation, configure CORS properly:
|
||||
|
||||
#### Simple CORS (Single Origin)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Wildcard Subdomains
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://*.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Development with Multiple Ports
|
||||
```yaml
|
||||
securityHeaders:
|
||||
profile: "development"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "http://localhost:*"
|
||||
- "http://127.0.0.1:*"
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
- Set `forceHttps: true`
|
||||
- Configure proper TLS certificates
|
||||
|
||||
2. **Implement proper CSP**
|
||||
- Start with strict policy
|
||||
- Add exceptions only when necessary
|
||||
- Test thoroughly
|
||||
|
||||
3. **Configure CORS restrictively**
|
||||
- Only allow necessary origins
|
||||
- Use specific domains instead of wildcards when possible
|
||||
|
||||
4. **Enable HSTS**
|
||||
- Use long max-age values (1 year minimum)
|
||||
- Include subdomains when appropriate
|
||||
|
||||
5. **Monitor security headers**
|
||||
- Use browser developer tools to verify headers
|
||||
- Test with security scanning tools
|
||||
- Regularly review and update policies
|
||||
|
||||
### Testing Security Headers
|
||||
|
||||
Use browser developer tools or online tools to verify your security headers:
|
||||
|
||||
1. **Browser DevTools**: Check Network tab → Response Headers
|
||||
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
|
||||
3. **Command line**: Use `curl -I https://your-domain.com`
|
||||
|
||||
Example verification:
|
||||
```bash
|
||||
curl -I https://your-domain.com
|
||||
# Should show security headers in response
|
||||
```
|
||||
+546
@@ -0,0 +1,546 @@
|
||||
# Redis Cache for Distributed Deployments
|
||||
|
||||
Redis cache support for multi-replica Traefik deployments with shared state.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Why Use Redis Cache?](#why-use-redis-cache)
|
||||
- [Configuration](#configuration)
|
||||
- [Cache Modes](#cache-modes)
|
||||
- [Deployment Examples](#deployment-examples)
|
||||
- [Performance Tuning](#performance-tuning)
|
||||
- [Monitoring](#monitoring)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [Migration Guide](#migration-guide)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
|
||||
- **Shared Session Management**: Consistent user sessions across replicas
|
||||
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
|
||||
- **Health Checking**: Continuous monitoring of Redis connectivity
|
||||
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
|
||||
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
||||
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3 │
|
||||
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
|
||||
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
|
||||
│ │ │
|
||||
└────────────────────┼────────────────────┘
|
||||
│
|
||||
┌──────▼──────┐
|
||||
│ Redis │
|
||||
│ (Shared │
|
||||
│ Cache) │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why Use Redis Cache?
|
||||
|
||||
### The Problem
|
||||
|
||||
When running multiple Traefik instances without shared cache:
|
||||
|
||||
1. **False Positive Replay Detection**
|
||||
- User authenticates → Token stored in Instance A's JTI cache
|
||||
- Next request → Load balancer routes to Instance B
|
||||
- Instance B doesn't have the JTI → Falsely detects replay attack
|
||||
|
||||
2. **Session Inconsistency**
|
||||
- User session created on Instance A
|
||||
- Subsequent request routed to Instance B
|
||||
- Instance B has no knowledge of the session
|
||||
|
||||
3. **Token Metadata Fragmentation**
|
||||
- Token refresh happens on Instance A
|
||||
- Other instances continue using old tokens
|
||||
|
||||
### The Solution
|
||||
|
||||
Redis provides centralized cache that all instances share, ensuring:
|
||||
|
||||
- **Consistent Authentication**: All instances share authentication state
|
||||
- **True Replay Detection**: JTI cache shared across all instances
|
||||
- **Seamless Scaling**: Add/remove instances without affecting sessions
|
||||
- **High Availability**: Circuit breaker with automatic fallback
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### All Configuration Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `enabled` | bool | `false` | Enable Redis caching |
|
||||
| `address` | string | - | Redis server address (`host:port`) |
|
||||
| `password` | string | - | Redis password (optional) |
|
||||
| `db` | int | `0` | Redis database number (0-15) |
|
||||
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
|
||||
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
|
||||
| `poolSize` | int | `10` | Connection pool size |
|
||||
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
|
||||
| `readTimeout` | int | `3` | Read timeout (seconds) |
|
||||
| `writeTimeout` | int | `3` | Write timeout (seconds) |
|
||||
| `enableTLS` | bool | `false` | Enable TLS for connections |
|
||||
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
|
||||
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
|
||||
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
|
||||
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
|
||||
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
|
||||
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
|
||||
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables are used:
|
||||
|
||||
```bash
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=your-password
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=10
|
||||
REDIS_CONNECT_TIMEOUT=5
|
||||
REDIS_READ_TIMEOUT=3
|
||||
REDIS_WRITE_TIMEOUT=3
|
||||
REDIS_ENABLE_TLS=false
|
||||
REDIS_TLS_SKIP_VERIFY=false
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Modes
|
||||
|
||||
### Memory Mode (Default without Redis)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "memory"
|
||||
```
|
||||
|
||||
- Uses only in-memory cache
|
||||
- Suitable for single-instance deployments
|
||||
- No Redis dependency
|
||||
- Fastest performance
|
||||
|
||||
### Redis Mode
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
- All operations go directly to Redis
|
||||
- Ensures consistency across replicas
|
||||
- Slightly higher latency
|
||||
|
||||
### Hybrid Mode (Recommended)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
Two-tier caching strategy:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Client Request │
|
||||
└────────────────┬────────────────────────┘
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Local Cache │ ← L1 Cache (Fast)
|
||||
│ (Memory) │
|
||||
└────────┬───────┘
|
||||
│ Miss
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Remote Cache │ ← L2 Cache (Shared)
|
||||
│ (Redis) │
|
||||
└────────────────┘
|
||||
```
|
||||
|
||||
**Read Path:**
|
||||
1. Check local memory cache (L1)
|
||||
2. On miss, check Redis (L2)
|
||||
3. On hit in Redis, populate L1
|
||||
4. Return value
|
||||
|
||||
**Write Path:**
|
||||
1. Write to Redis (L2) for durability
|
||||
2. Write to local cache (L1) for speed
|
||||
|
||||
### Performance Comparison
|
||||
|
||||
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|
||||
|-----------|------------|------------|-------------|
|
||||
| Read (p50) | 0.1ms | 2ms | 0.2ms |
|
||||
| Read (p99) | 0.5ms | 10ms | 5ms |
|
||||
| Write (p50) | 0.2ms | 3ms | 3ms |
|
||||
| Throughput | 100k/s | 20k/s | 80k/s |
|
||||
|
||||
---
|
||||
|
||||
## Deployment Examples
|
||||
|
||||
### Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD}
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 30s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
deploy:
|
||||
replicas: 3
|
||||
labels:
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
```
|
||||
|
||||
### AWS ElastiCache
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "your-cache.abc123.cache.amazonaws.com:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableTLS: true
|
||||
password: "your-elasticache-auth-token"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Connection Pool Sizing
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
poolSize: 20 # Formula: 2 * CPU cores * replicas
|
||||
# For 4 cores, 3 replicas: poolSize = 24
|
||||
```
|
||||
|
||||
### TTL Strategy
|
||||
|
||||
The plugin automatically sets TTLs based on token lifetimes:
|
||||
|
||||
- **JTI Cache**: Matches token lifetime (typically 1 hour)
|
||||
- **Session**: Matches `sessionMaxAge` configuration
|
||||
- **Token Metadata**: 5 minutes (short-lived)
|
||||
|
||||
### Redis Server Configuration
|
||||
|
||||
```bash
|
||||
# Recommended Redis settings for cache
|
||||
maxmemory 512mb
|
||||
maxmemory-policy allkeys-lru # Evict least recently used
|
||||
|
||||
# For cache data, disable persistence for better performance
|
||||
save ""
|
||||
appendonly no
|
||||
```
|
||||
|
||||
### Hybrid Mode Tuning
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "hybrid"
|
||||
hybridL1Size: 500 # Max items in local cache
|
||||
hybridL1MemoryMB: 10 # Max memory for local cache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Key Metrics
|
||||
|
||||
- **Cache hit rate** (target: >90% for hybrid mode)
|
||||
- **Redis latency** (target: <10ms p99)
|
||||
- **Circuit breaker state**
|
||||
- **Connection pool utilization
|
||||
|
||||
### Redis Commands for Monitoring
|
||||
|
||||
```bash
|
||||
# Monitor commands in real-time
|
||||
redis-cli MONITOR
|
||||
|
||||
# Check slow queries
|
||||
redis-cli SLOWLOG GET 10
|
||||
|
||||
# Memory usage
|
||||
redis-cli INFO memory
|
||||
|
||||
# Key statistics
|
||||
redis-cli DBSIZE
|
||||
|
||||
# List keys with prefix
|
||||
redis-cli --scan --pattern "traefikoidc:*"
|
||||
|
||||
# Check key TTL
|
||||
redis-cli TTL "traefikoidc:session:abc123"
|
||||
```
|
||||
|
||||
### Health Check Endpoint
|
||||
|
||||
The plugin provides health information including:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"cache": {
|
||||
"mode": "hybrid",
|
||||
"redis": {
|
||||
"connected": true,
|
||||
"latency": "2ms"
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": "closed",
|
||||
"failures": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Symptoms:** `dial tcp: connection refused`
|
||||
|
||||
**Solutions:**
|
||||
1. Verify Redis is running: `redis-cli ping`
|
||||
2. Check network connectivity: `telnet redis-host 6379`
|
||||
3. Verify address configuration
|
||||
|
||||
### Authentication Failure
|
||||
|
||||
**Symptoms:** `NOAUTH Authentication required`
|
||||
|
||||
**Solutions:**
|
||||
1. Set Redis password in configuration
|
||||
2. Verify password is correct
|
||||
|
||||
### Circuit Breaker Open
|
||||
|
||||
**Symptoms:** `Circuit breaker is open`, falling back to memory
|
||||
|
||||
**Solutions:**
|
||||
1. Check Redis health: `redis-cli INFO server`
|
||||
2. Review network latency: `redis-cli --latency`
|
||||
3. Adjust circuit breaker thresholds if needed
|
||||
|
||||
### High Memory Usage
|
||||
|
||||
**Symptoms:** Redis memory constantly growing, OOM errors
|
||||
|
||||
**Solutions:**
|
||||
1. Configure eviction policy:
|
||||
```bash
|
||||
CONFIG SET maxmemory 512mb
|
||||
CONFIG SET maxmemory-policy allkeys-lru
|
||||
```
|
||||
2. Review key count: `redis-cli DBSIZE`
|
||||
3. Check for large keys: `redis-cli --bigkeys`
|
||||
|
||||
### Inconsistent Cache State
|
||||
|
||||
**Symptoms:** Different responses from different replicas
|
||||
|
||||
**Solutions:**
|
||||
1. Verify all instances use the same Redis address
|
||||
2. Check cache mode consistency across instances
|
||||
3. Verify time synchronization on all hosts
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Memory-Only to Redis
|
||||
|
||||
#### Phase 1: Preparation
|
||||
|
||||
1. Deploy Redis infrastructure
|
||||
2. Test Redis connectivity
|
||||
3. Configure monitoring
|
||||
|
||||
#### Phase 2: Gradual Rollout
|
||||
|
||||
1. Enable Redis on one instance:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
2. Monitor for errors
|
||||
3. Gradually enable on more instances
|
||||
|
||||
#### Phase 3: Full Migration
|
||||
|
||||
1. Enable Redis on all instances
|
||||
2. Remove `disableReplayDetection: true` if set
|
||||
3. Monitor for issues
|
||||
|
||||
### Rollback Plan
|
||||
|
||||
If issues occur:
|
||||
1. Set `redis.enabled: false`
|
||||
2. Plugin falls back to memory cache automatically
|
||||
3. Investigate and resolve issues
|
||||
|
||||
### Migration Checklist
|
||||
|
||||
- [ ] Redis deployed and accessible
|
||||
- [ ] Redis password configured
|
||||
- [ ] Network connectivity verified
|
||||
- [ ] Monitoring configured
|
||||
- [ ] Backup plan prepared
|
||||
- [ ] Test environment validated
|
||||
- [ ] Gradual rollout planned
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Security
|
||||
|
||||
- Always use Redis password authentication
|
||||
- Enable TLS for production deployments
|
||||
- Use network segmentation (private subnets)
|
||||
- Rotate Redis passwords regularly
|
||||
|
||||
### High Availability
|
||||
|
||||
- Use Redis Sentinel or Cluster for HA
|
||||
- Configure appropriate circuit breaker thresholds
|
||||
- Implement proper health checks
|
||||
- Use connection pooling
|
||||
|
||||
### Performance
|
||||
|
||||
- Use hybrid cache mode for best performance
|
||||
- Monitor cache hit rates
|
||||
- Size Redis memory appropriately
|
||||
- Disable persistence for cache-only usage
|
||||
|
||||
### Operations
|
||||
|
||||
- Implement comprehensive monitoring
|
||||
- Set up alerting for circuit breaker state
|
||||
- Document Redis configuration
|
||||
- Test failover scenarios
|
||||
|
||||
---
|
||||
|
||||
## FAQ
|
||||
|
||||
### Is Redis required?
|
||||
|
||||
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
|
||||
|
||||
### What happens if Redis goes down?
|
||||
|
||||
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
|
||||
|
||||
### Which cache mode should I use?
|
||||
|
||||
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
|
||||
|
||||
### How much memory does Redis need?
|
||||
|
||||
Depends on active sessions and token sizes:
|
||||
- Small (1-1000 users): 128MB
|
||||
- Medium (1000-10000 users): 256-512MB
|
||||
- Large (10000+ users): 1GB+
|
||||
|
||||
### Can I use managed Redis services?
|
||||
|
||||
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
|
||||
|
||||
### Is data encrypted in Redis?
|
||||
|
||||
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
|
||||
+390
@@ -0,0 +1,390 @@
|
||||
# Testing Guide
|
||||
|
||||
Comprehensive testing infrastructure for traefikoidc.
|
||||
|
||||
## Overview
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Test files | 99 |
|
||||
| Lines of test code | ~65,500 |
|
||||
| Code coverage | 71.0% |
|
||||
| Race conditions | None (all pass with `-race`) |
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run with race detection
|
||||
go test -race ./...
|
||||
|
||||
# Run with coverage
|
||||
go test -cover ./...
|
||||
|
||||
# Run specific test suite
|
||||
go test -v -run "TokenValidationSuite" .
|
||||
|
||||
# Run edge case tests
|
||||
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
|
||||
```
|
||||
|
||||
## Test Infrastructure
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
internal/testutil/
|
||||
├── compat.go # Re-exports for main package access
|
||||
├── mocks/
|
||||
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
|
||||
│ ├── session.go # SessionManager, SessionData
|
||||
│ ├── cache.go # Cache, TokenCache, Blacklist
|
||||
│ └── interfaces_test.go # Mock verification tests
|
||||
├── fixtures/
|
||||
│ └── tokens.go # JWT token generation fixtures
|
||||
└── servers/
|
||||
├── oidc.go # Mock OIDC server factory
|
||||
└── oidc_test.go # Server tests
|
||||
```
|
||||
|
||||
### Test Suites
|
||||
|
||||
| Suite | File | Description |
|
||||
|-------|------|-------------|
|
||||
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
|
||||
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
|
||||
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
|
||||
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
|
||||
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
|
||||
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
|
||||
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
|
||||
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
|
||||
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
|
||||
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
|
||||
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
|
||||
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
|
||||
|
||||
## Mock Types
|
||||
|
||||
The project provides two mocking patterns:
|
||||
|
||||
### State-Based Mocks (Basic)
|
||||
|
||||
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
|
||||
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
|
||||
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
|
||||
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
|
||||
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
|
||||
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
jwkCache: mock,
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### Enhanced State-Based Mocks (with Call Tracking)
|
||||
|
||||
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
|
||||
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
|
||||
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
|
||||
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
}
|
||||
|
||||
// Make calls
|
||||
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
|
||||
|
||||
// Verify calls were made
|
||||
mock.AssertGetJWKSCalled(t)
|
||||
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(t, 1)
|
||||
|
||||
// Access call details
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Track all calls with parameters and timestamps
|
||||
- Built-in assertion helpers using testify
|
||||
- Thread-safe for concurrent tests
|
||||
- `Reset()` method to clear state between tests
|
||||
- `LastCall()` to inspect most recent call
|
||||
|
||||
### Testify-Based Mocks
|
||||
|
||||
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
|
||||
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
|
||||
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
|
||||
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
|
||||
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
|
||||
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &TestifyJWKCache{}
|
||||
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
|
||||
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
|
||||
|
||||
// After test
|
||||
mock.AssertExpectations(t)
|
||||
```
|
||||
|
||||
### Testutil Package Mocks
|
||||
|
||||
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
|
||||
|
||||
```go
|
||||
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
|
||||
mock := testutil.NewJWKCacheMock()
|
||||
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
|
||||
```
|
||||
|
||||
### Choosing the Right Mock
|
||||
|
||||
| Use Case | Recommended Mock |
|
||||
|----------|-----------------|
|
||||
| Simple return values only | Basic state-based (`MockJWKCache`) |
|
||||
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
|
||||
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
|
||||
| Verify call order/sequence | Testify-based |
|
||||
| HTTP endpoint simulation | `MockOAuthProvider` |
|
||||
| New testify suite tests | Enhanced or Testify-based |
|
||||
|
||||
**Decision Guide:**
|
||||
|
||||
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
|
||||
|
||||
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
|
||||
|
||||
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
|
||||
|
||||
## Token Fixtures
|
||||
|
||||
The `testutil.TokenFixture` generates JWT tokens for testing:
|
||||
|
||||
```go
|
||||
fixture, err := testutil.NewTokenFixture()
|
||||
|
||||
// Valid token with default claims
|
||||
token, _ := fixture.ValidToken(nil)
|
||||
|
||||
// Token with custom claims
|
||||
token, _ := fixture.ValidToken(map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"roles": []string{"admin"},
|
||||
})
|
||||
|
||||
// Expired token
|
||||
token, _ := fixture.ExpiredToken()
|
||||
|
||||
// Token with specific roles/groups
|
||||
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
|
||||
token, _ := fixture.TokenWithGroups([]string{"developers"})
|
||||
|
||||
// Token with clock skew
|
||||
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
|
||||
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
|
||||
|
||||
// Token missing specific claims
|
||||
token, _ := fixture.TokenMissingClaim("email", "sub")
|
||||
|
||||
// Malformed token
|
||||
token := fixture.MalformedToken() // "not.a.valid.jwt"
|
||||
|
||||
// Get JWKS for verification
|
||||
jwks := fixture.GetJWKS()
|
||||
```
|
||||
|
||||
## Mock OIDC Server
|
||||
|
||||
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
|
||||
|
||||
```go
|
||||
// Default configuration
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
defer server.Close()
|
||||
|
||||
// Custom configuration
|
||||
config := testutil.DefaultServerConfig()
|
||||
config.Issuer = "https://custom-issuer.com"
|
||||
config.TokenError = &testutil.OIDCError{
|
||||
Error: "invalid_grant",
|
||||
Description: "Authorization code expired",
|
||||
}
|
||||
server := testutil.NewOIDCServer(config)
|
||||
|
||||
// Provider-specific configurations
|
||||
googleConfig := testutil.GoogleServerConfig()
|
||||
azureConfig := testutil.AzureServerConfig()
|
||||
auth0Config := testutil.Auth0ServerConfig()
|
||||
keycloakConfig := testutil.KeycloakServerConfig()
|
||||
|
||||
// Behavior configurations
|
||||
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
|
||||
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
|
||||
```
|
||||
|
||||
### Server Endpoints
|
||||
|
||||
| Endpoint | Description |
|
||||
|----------|-------------|
|
||||
| `/.well-known/openid-configuration` | OIDC discovery document |
|
||||
| `/authorize` | Authorization endpoint |
|
||||
| `/token` | Token exchange endpoint |
|
||||
| `/jwks` | JSON Web Key Set |
|
||||
| `/userinfo` | User information endpoint |
|
||||
| `/introspect` | Token introspection |
|
||||
| `/revoke` | Token revocation |
|
||||
| `/logout` | End session endpoint |
|
||||
|
||||
### Request Tracking
|
||||
|
||||
```go
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
|
||||
// Make requests...
|
||||
|
||||
count := server.GetRequestCount()
|
||||
requests := server.GetRequests()
|
||||
server.Reset() // Clear tracking
|
||||
```
|
||||
|
||||
## Writing Test Suites
|
||||
|
||||
### Basic Suite Structure
|
||||
|
||||
```go
|
||||
type MyTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupTest() {
|
||||
// Per-test setup
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TearDownTest() {
|
||||
// Per-test cleanup
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TestSomething() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestMyTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MyTestSuite))
|
||||
}
|
||||
```
|
||||
|
||||
### Table-Driven Tests
|
||||
|
||||
```go
|
||||
func (s *MyTestSuite) TestClockSkewEdgeCases() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skew time.Duration
|
||||
shouldPass bool
|
||||
}{
|
||||
{"valid_token", 5 * time.Minute, true},
|
||||
{"expired_within_tolerance", -1 * time.Minute, true},
|
||||
{"expired_beyond_tolerance", -10 * time.Minute, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
token, err := s.fixture.TokenWithSkew(tc.skew)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
if tc.shouldPass {
|
||||
s.NoError(err)
|
||||
} else {
|
||||
s.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Happy Path Tests
|
||||
|
||||
Test the expected successful scenarios:
|
||||
|
||||
- Valid token verification
|
||||
- Successful token exchange
|
||||
- Session creation and retrieval
|
||||
- Cache operations
|
||||
|
||||
### Error Case Tests
|
||||
|
||||
Test failure scenarios:
|
||||
|
||||
- Expired tokens
|
||||
- Invalid signatures
|
||||
- Wrong issuer/audience
|
||||
- Network failures
|
||||
- Rate limiting
|
||||
|
||||
### Edge Case Tests
|
||||
|
||||
Test boundary conditions:
|
||||
|
||||
- Clock skew tolerance boundaries
|
||||
- Unicode/emoji in claims
|
||||
- Very large claim values
|
||||
- Concurrent access
|
||||
- Special characters in URLs
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use fixtures for token generation** - Don't manually construct JWTs
|
||||
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
|
||||
3. **Always run with `-race`** - Catch concurrency issues early
|
||||
4. **Use testify assertions** - Better error messages and cleaner code
|
||||
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
|
||||
6. **Test edge cases systematically** - Use table-driven tests
|
||||
@@ -1,163 +0,0 @@
|
||||
# Google OAuth Integration Fix
|
||||
|
||||
## Problem Overview
|
||||
|
||||
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
|
||||
|
||||
```
|
||||
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
|
||||
```
|
||||
|
||||
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
|
||||
|
||||
## Technical Details of the Issue
|
||||
|
||||
### Standard OIDC Provider Behavior
|
||||
|
||||
Most OpenID Connect (OIDC) providers follow the standard specification, where:
|
||||
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
|
||||
- This allows authenticated sessions to persist beyond the initial access token expiration
|
||||
|
||||
### Google's Non-Standard Approach
|
||||
|
||||
Google's OAuth implementation deviates from the standard by:
|
||||
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
|
||||
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
|
||||
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
|
||||
|
||||
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
|
||||
|
||||
## Solution Implementation
|
||||
|
||||
The fix involved modifying the authentication flow to specifically handle Google providers:
|
||||
|
||||
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
|
||||
|
||||
```go
|
||||
// Check if we're dealing with a Google OIDC provider
|
||||
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
|
||||
strings.Contains(t.issuerURL, "accounts.google.com")
|
||||
```
|
||||
|
||||
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
|
||||
|
||||
```go
|
||||
// Handle offline access differently for Google vs other providers
|
||||
if isGoogleProvider {
|
||||
// For Google, use access_type=offline parameter instead of offline_access scope
|
||||
params.Set("access_type", "offline")
|
||||
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
// Add prompt=consent for Google to ensure refresh token is issued
|
||||
params.Set("prompt", "consent")
|
||||
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else {
|
||||
// For non-Google providers, use the offline_access scope
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
|
||||
|
||||
## Why This Approach Works
|
||||
|
||||
This solution aligns with Google's OAuth 2.0 documentation which specifies:
|
||||
|
||||
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
|
||||
|
||||
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
|
||||
|
||||
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
|
||||
|
||||
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
|
||||
|
||||
## Testing and Verification
|
||||
|
||||
Comprehensive tests were implemented to verify the solution:
|
||||
|
||||
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
|
||||
|
||||
2. **Auth URL Parameter Tests**: Verifies that:
|
||||
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
|
||||
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
|
||||
|
||||
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
|
||||
|
||||
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
|
||||
|
||||
Sample test case (simplified):
|
||||
|
||||
```go
|
||||
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
|
||||
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that access_type=offline was added (not offline_access scope for Google)
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is NOT included for Google providers
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Check that prompt=consent was added
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Usage Guidance for Developers
|
||||
|
||||
When configuring the Traefik OIDC middleware for Google:
|
||||
|
||||
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
|
||||
|
||||
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
|
||||
- Configure the authorized redirect URI to match your `callbackURL` setting
|
||||
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
|
||||
|
||||
3. **Configuration Example**:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-google-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-google-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware handles this automatically and correctly
|
||||
```
|
||||
|
||||
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
|
||||
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
|
||||
- Review your application logs with `logLevel: debug` to check for refresh token errors
|
||||
- Verify you're using a version of the middleware that includes this fix
|
||||
|
||||
## Conclusion
|
||||
|
||||
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
|
||||
+1373
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,540 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationError represents an error response from client registration (RFC 7591)
|
||||
type ClientRegistrationError struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrar struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrar creates a new dynamic client registrar
|
||||
func NewDynamicClientRegistrar(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
return nil, fmt.Errorf("dynamic client registration is not enabled")
|
||||
}
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
return resp, nil
|
||||
}
|
||||
r.logger.Info("Existing credentials expired or invalid, registering new client")
|
||||
}
|
||||
}
|
||||
|
||||
// Determine registration endpoint
|
||||
endpoint := registrationEndpoint
|
||||
if r.config.RegistrationEndpoint != "" {
|
||||
endpoint = r.config.RegistrationEndpoint
|
||||
}
|
||||
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
|
||||
}
|
||||
|
||||
// Validate the endpoint URL
|
||||
if !strings.HasPrefix(endpoint, "https://") {
|
||||
// Allow http only for localhost/development
|
||||
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
|
||||
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
|
||||
}
|
||||
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
|
||||
}
|
||||
|
||||
// Build registration request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build registration request: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registration request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Add Initial Access Token if provided
|
||||
if r.config.InitialAccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read registration response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registration response: %w", err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if regResp.ClientID == "" {
|
||||
return nil, fmt.Errorf("registration response missing client_id")
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
|
||||
|
||||
// Cache the response
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// buildRegistrationRequest creates the JSON request body for client registration
|
||||
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
|
||||
metadata := r.config.ClientMetadata
|
||||
if metadata == nil {
|
||||
metadata = &ClientRegistrationMetadata{}
|
||||
}
|
||||
|
||||
// Build request object
|
||||
reqData := make(map[string]interface{})
|
||||
|
||||
// Required: redirect_uris
|
||||
if len(metadata.RedirectURIs) > 0 {
|
||||
reqData["redirect_uris"] = metadata.RedirectURIs
|
||||
} else {
|
||||
return nil, fmt.Errorf("redirect_uris is required for client registration")
|
||||
}
|
||||
|
||||
// Optional fields - only include if set
|
||||
if len(metadata.ResponseTypes) > 0 {
|
||||
reqData["response_types"] = metadata.ResponseTypes
|
||||
} else {
|
||||
// Default to authorization code flow
|
||||
reqData["response_types"] = []string{"code"}
|
||||
}
|
||||
|
||||
if len(metadata.GrantTypes) > 0 {
|
||||
reqData["grant_types"] = metadata.GrantTypes
|
||||
} else {
|
||||
// Default grant types for authorization code flow
|
||||
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
|
||||
}
|
||||
|
||||
if metadata.ApplicationType != "" {
|
||||
reqData["application_type"] = metadata.ApplicationType
|
||||
}
|
||||
|
||||
if len(metadata.Contacts) > 0 {
|
||||
reqData["contacts"] = metadata.Contacts
|
||||
}
|
||||
|
||||
if metadata.ClientName != "" {
|
||||
reqData["client_name"] = metadata.ClientName
|
||||
}
|
||||
|
||||
if metadata.LogoURI != "" {
|
||||
reqData["logo_uri"] = metadata.LogoURI
|
||||
}
|
||||
|
||||
if metadata.ClientURI != "" {
|
||||
reqData["client_uri"] = metadata.ClientURI
|
||||
}
|
||||
|
||||
if metadata.PolicyURI != "" {
|
||||
reqData["policy_uri"] = metadata.PolicyURI
|
||||
}
|
||||
|
||||
if metadata.TOSURI != "" {
|
||||
reqData["tos_uri"] = metadata.TOSURI
|
||||
}
|
||||
|
||||
if metadata.JWKSURI != "" {
|
||||
reqData["jwks_uri"] = metadata.JWKSURI
|
||||
}
|
||||
|
||||
if metadata.SubjectType != "" {
|
||||
reqData["subject_type"] = metadata.SubjectType
|
||||
}
|
||||
|
||||
if metadata.TokenEndpointAuthMethod != "" {
|
||||
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
|
||||
} else {
|
||||
// Default to client_secret_basic for confidential clients
|
||||
reqData["token_endpoint_auth_method"] = "client_secret_basic"
|
||||
}
|
||||
|
||||
if metadata.DefaultMaxAge > 0 {
|
||||
reqData["default_max_age"] = metadata.DefaultMaxAge
|
||||
}
|
||||
|
||||
if metadata.RequireAuthTime {
|
||||
reqData["require_auth_time"] = metadata.RequireAuthTime
|
||||
}
|
||||
|
||||
if len(metadata.DefaultACRValues) > 0 {
|
||||
reqData["default_acr_values"] = metadata.DefaultACRValues
|
||||
}
|
||||
|
||||
if metadata.Scope != "" {
|
||||
reqData["scope"] = metadata.Scope
|
||||
}
|
||||
|
||||
return json.Marshal(reqData)
|
||||
}
|
||||
|
||||
// GetCachedResponse returns the cached registration response
|
||||
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.registrationResponse
|
||||
}
|
||||
|
||||
// areCredentialsValid checks if the cached credentials are still valid
|
||||
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
|
||||
if resp == nil || resp.ClientID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if secret has expired
|
||||
if resp.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
|
||||
// Add 5 minute buffer before expiration
|
||||
if time.Now().Add(5 * time.Minute).After(expiresAt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// credentialsFilePath returns the path for storing credentials
|
||||
func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
if r.config.CredentialsFile != "" {
|
||||
return r.config.CredentialsFile
|
||||
}
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var resp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,620 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
|
||||
type ClockSkewEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
|
||||
// Create JWK for the test key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error") // Reduce noise
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(0)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Token at exact expiry - behavior is implementation-defined
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.T().Logf("Exact expiry result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token should be valid 1 second before expiry")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
|
||||
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
|
||||
// Most implementations allow 5-minute clock skew
|
||||
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// May pass or fail depending on implementation
|
||||
s.T().Logf("4-minute expired token result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
|
||||
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.Error(err, "Token should be invalid 10 minutes after expiry")
|
||||
}
|
||||
|
||||
func TestClockSkewEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClockSkewEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// UnicodeClaimsSuite tests Unicode handling in JWT claims
|
||||
type UnicodeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
|
||||
token, err := s.fixture.TokenWithEmail("用户@example.com")
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode email should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeName() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "田中太郎",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode name should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test User 😀",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Emoji in claims should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestRTLText() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "مستخدم اختبار",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "RTL text should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestMixedScripts() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test 测试 テスト",
|
||||
"roles": []string{"admin", "管理者", "管理员"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Mixed scripts should be handled correctly")
|
||||
}
|
||||
|
||||
func TestUnicodeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(UnicodeClaimsSuite))
|
||||
}
|
||||
|
||||
// LargeClaimsSuite tests large claim values
|
||||
type LargeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyRoles() {
|
||||
roles := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithRoles(roles)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 100 roles should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyGroups() {
|
||||
groups := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithGroups(groups)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 50 groups should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongEmail() {
|
||||
// RFC 5321 allows up to 254 characters
|
||||
localPart := strings.Repeat("a", 64)
|
||||
domain := strings.Repeat("b", 63) + ".com"
|
||||
email := localPart + "@" + domain
|
||||
|
||||
token, err := s.fixture.TokenWithEmail(email)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long email should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongSubject() {
|
||||
longSub := strings.Repeat("subject", 100)
|
||||
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"sub": longSub,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long subject should be handled")
|
||||
}
|
||||
|
||||
func TestLargeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(LargeClaimsSuite))
|
||||
}
|
||||
|
||||
// URLPathEdgeCasesSuite tests URL handling edge cases
|
||||
type URLPathEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
|
||||
longPath := "/" + strings.Repeat("segment/", 100)
|
||||
req := httptest.NewRequest("GET", longPath, nil)
|
||||
|
||||
s.NotNil(req)
|
||||
s.Contains(req.URL.Path, "segment")
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
|
||||
paths := []string{
|
||||
"/path%20with%20spaces",
|
||||
"/path/with/日本語",
|
||||
"/path?query=value&another=test",
|
||||
"/path#fragment",
|
||||
"/path/../traversal",
|
||||
"/path/./current",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
s.Run(path, func() {
|
||||
req := httptest.NewRequest("GET", path, nil)
|
||||
s.NotNil(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
s.Equal("/", req.URL.Path)
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
|
||||
req := httptest.NewRequest("GET", "//double//slashes//", nil)
|
||||
s.NotNil(req)
|
||||
}
|
||||
|
||||
func TestURLPathEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(URLPathEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
|
||||
type ConcurrencyEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"custom": idx,
|
||||
})
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent different token error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent different token validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
|
||||
validToken, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
expiredToken, err := s.fixture.ExpiredToken()
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 40
|
||||
var wg sync.WaitGroup
|
||||
validCount := int32(0)
|
||||
expiredCount := int32(0)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
var token string
|
||||
if idx%2 == 0 {
|
||||
token = validToken
|
||||
} else {
|
||||
token = expiredToken
|
||||
}
|
||||
|
||||
err := s.tOidc.VerifyToken(token)
|
||||
if idx%2 == 0 {
|
||||
if err == nil {
|
||||
atomic.AddInt32(&validCount, 1)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
atomic.AddInt32(&expiredCount, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
|
||||
}
|
||||
|
||||
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
|
||||
type EnhancedMocksSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Another call with different URL
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
|
||||
|
||||
// Verify calls were tracked
|
||||
s.Equal(2, mock.GetJWKSCallCount())
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(s.T(), 2)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
|
||||
expectedErr := errors.New("network error")
|
||||
mock := &EnhancedMockJWKCache{
|
||||
Err: expectedErr,
|
||||
}
|
||||
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
|
||||
s.Nil(result)
|
||||
s.Equal(expectedErr, err)
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
|
||||
mock.Reset()
|
||||
|
||||
s.Equal(0, mock.GetJWKSCallCount())
|
||||
s.Nil(mock.JWKS)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
Err: nil, // Valid tokens
|
||||
}
|
||||
|
||||
// Verify a token
|
||||
err := mock.VerifyToken("test-token-1")
|
||||
s.NoError(err)
|
||||
|
||||
// Verify another token
|
||||
err = mock.VerifyToken("test-token-2")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
mock.AssertVerifyTokenCalled(s.T())
|
||||
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
|
||||
|
||||
// Check last call
|
||||
lastCall := mock.LastCall()
|
||||
s.NotNil(lastCall)
|
||||
s.Equal("test-token-2", lastCall.Token)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
|
||||
callCount := 0
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
callCount++
|
||||
if token == "invalid" {
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Valid token
|
||||
err := mock.VerifyToken("valid-token")
|
||||
s.NoError(err)
|
||||
|
||||
// Invalid token
|
||||
err = mock.VerifyToken("invalid")
|
||||
s.Error(err)
|
||||
|
||||
s.Equal(2, callCount)
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeResponse: &TokenResponse{
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Exchange code
|
||||
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
|
||||
s.NoError(err)
|
||||
s.Equal("access-token", resp.AccessToken)
|
||||
|
||||
// Refresh token
|
||||
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
|
||||
s.NoError(err)
|
||||
s.Equal("new-access-token", resp.AccessToken)
|
||||
|
||||
// Revoke token
|
||||
err = mock.RevokeTokenWithProvider("access-token", "access_token")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
mock.AssertExchangeCalled(s.T())
|
||||
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
|
||||
mock.AssertRefreshCalled(s.T())
|
||||
mock.AssertRevokeCalled(s.T())
|
||||
|
||||
s.Equal(1, mock.GetExchangeCallCount())
|
||||
s.Equal(1, mock.GetRefreshCallCount())
|
||||
s.Equal(1, mock.GetRevokeCallCount())
|
||||
|
||||
// Check last exchange call details
|
||||
lastExchange := mock.LastExchangeCall()
|
||||
s.NotNil(lastExchange)
|
||||
s.Equal("authorization_code", lastExchange.GrantType)
|
||||
s.Equal("auth-code", lastExchange.CodeOrToken)
|
||||
s.Equal("https://redirect.com", lastExchange.RedirectURL)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeErr: errors.New("invalid_grant"),
|
||||
RefreshErr: errors.New("refresh_expired"),
|
||||
RevokeErr: errors.New("revoke_failed"),
|
||||
}
|
||||
|
||||
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "invalid_grant")
|
||||
|
||||
_, err = mock.GetNewTokenWithRefreshToken("token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "refresh_expired")
|
||||
|
||||
err = mock.RevokeTokenWithProvider("token", "access_token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "revoke_failed")
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// Set some values
|
||||
mock.Set("key1", "value1", 5*time.Minute)
|
||||
mock.Set("key2", "value2", 10*time.Minute)
|
||||
|
||||
// Get values
|
||||
val, found := mock.Get("key1")
|
||||
s.True(found)
|
||||
s.Equal("value1", val)
|
||||
|
||||
_, found = mock.Get("nonexistent")
|
||||
s.False(found)
|
||||
|
||||
// Delete
|
||||
mock.Delete("key1")
|
||||
|
||||
// Verify tracking
|
||||
mock.AssertSetCalled(s.T(), "key1")
|
||||
mock.AssertSetCalled(s.T(), "key2")
|
||||
mock.AssertGetCalled(s.T(), "key1")
|
||||
mock.AssertGetCalled(s.T(), "nonexistent")
|
||||
mock.AssertDeleteCalled(s.T(), "key1")
|
||||
|
||||
s.Equal(2, mock.SetCallCount())
|
||||
s.Equal(2, mock.GetCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// The enhanced mock actually stores data
|
||||
mock.Set("key", "value", time.Hour)
|
||||
s.Equal(1, mock.Size())
|
||||
|
||||
val, found := mock.Get("key")
|
||||
s.True(found)
|
||||
s.Equal("value", val)
|
||||
|
||||
mock.Delete("key")
|
||||
s.Equal(0, mock.Size())
|
||||
|
||||
_, found = mock.Get("key")
|
||||
s.False(found)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
mock.Set("key1", "value1", time.Hour)
|
||||
mock.Set("key2", "value2", time.Hour)
|
||||
s.Equal(2, mock.Size())
|
||||
|
||||
mock.Clear()
|
||||
s.Equal(0, mock.Size())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Concurrent calls should be safe
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
s.Equal(10, mock.GetJWKSCallCount())
|
||||
}
|
||||
|
||||
func TestEnhancedMocksSuite(t *testing.T) {
|
||||
suite.Run(t, new(EnhancedMocksSuite))
|
||||
}
|
||||
@@ -0,0 +1,577 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// EnhancedMockJWKCache is an improved state-based mock with call tracking
|
||||
type EnhancedMockJWKCache struct {
|
||||
Err error
|
||||
JWKS *JWKSet
|
||||
GetJWKSCalls []JWKSCall
|
||||
mu sync.RWMutex
|
||||
getJWKSCallsMu sync.Mutex
|
||||
CleanupCalls int32
|
||||
CloseCalls int32
|
||||
}
|
||||
|
||||
// JWKSCall records parameters from a GetJWKS call
|
||||
type JWKSCall struct {
|
||||
Timestamp time.Time
|
||||
URL string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
|
||||
URL: jwksURL,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Close() {
|
||||
atomic.AddInt32(&m.CloseCalls, 1)
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetJWKSCalled verifies GetJWKS was called
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
|
||||
}
|
||||
|
||||
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
for _, call := range m.GetJWKSCalls {
|
||||
if call.URL == expectedURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
|
||||
}
|
||||
|
||||
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
|
||||
}
|
||||
|
||||
// GetJWKSCallCount returns the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return len(m.GetJWKSCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockJWKCache) Reset() {
|
||||
m.mu.Lock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = nil
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
atomic.StoreInt32(&m.CleanupCalls, 0)
|
||||
atomic.StoreInt32(&m.CloseCalls, 0)
|
||||
}
|
||||
|
||||
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenVerifier struct {
|
||||
Err error
|
||||
VerifyFunc func(token string) error
|
||||
VerifyCalls []TokenVerifyCall
|
||||
mu sync.RWMutex
|
||||
verifyCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// TokenVerifyCall records parameters from a VerifyToken call
|
||||
type TokenVerifyCall struct {
|
||||
Timestamp time.Time
|
||||
Result error
|
||||
Token string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
|
||||
var result error
|
||||
|
||||
m.mu.RLock()
|
||||
if m.VerifyFunc != nil {
|
||||
result = m.VerifyFunc(token)
|
||||
} else {
|
||||
result = m.Err
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
|
||||
Token: token,
|
||||
Timestamp: time.Now(),
|
||||
Result: result,
|
||||
})
|
||||
m.verifyCallsMu.Unlock()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertVerifyTokenCalled verifies VerifyToken was called
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
for _, call := range m.VerifyCalls {
|
||||
if call.Token == expectedToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "VerifyToken was not called with expected token")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
|
||||
}
|
||||
|
||||
// GetVerifyTokenCallCount returns the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return len(m.VerifyCalls)
|
||||
}
|
||||
|
||||
// LastCall returns the most recent VerifyToken call
|
||||
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
if len(m.VerifyCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.VerifyCalls[len(m.VerifyCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenVerifier) Reset() {
|
||||
m.mu.Lock()
|
||||
m.Err = nil
|
||||
m.VerifyFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = nil
|
||||
m.verifyCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenExchanger struct {
|
||||
RefreshErr error
|
||||
RevokeErr error
|
||||
ExchangeErr error
|
||||
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshResponse *TokenResponse
|
||||
ExchangeResponse *TokenResponse
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
ExchangeCalls []ExchangeCall
|
||||
RefreshCalls []RefreshCall
|
||||
RevokeCalls []RevokeCall
|
||||
mu sync.RWMutex
|
||||
exchangeCallsMu sync.Mutex
|
||||
refreshCallsMu sync.Mutex
|
||||
revokeCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// ExchangeCall records parameters from an ExchangeCodeForToken call
|
||||
type ExchangeCall struct {
|
||||
Timestamp time.Time
|
||||
GrantType string
|
||||
CodeOrToken string
|
||||
RedirectURL string
|
||||
CodeVerifier string
|
||||
}
|
||||
|
||||
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
|
||||
type RefreshCall struct {
|
||||
Timestamp time.Time
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// RevokeCall records parameters from a RevokeTokenWithProvider call
|
||||
type RevokeCall struct {
|
||||
Timestamp time.Time
|
||||
Token string
|
||||
TokenType string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
|
||||
GrantType: grantType,
|
||||
CodeOrToken: codeOrToken,
|
||||
RedirectURL: redirectURL,
|
||||
CodeVerifier: codeVerifier,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.ExchangeCodeFunc != nil {
|
||||
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
return m.ExchangeResponse, m.ExchangeErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
|
||||
RefreshToken: refreshToken,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RefreshTokenFunc != nil {
|
||||
return m.RefreshTokenFunc(refreshToken)
|
||||
}
|
||||
return m.RefreshResponse, m.RefreshErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
|
||||
Token: token,
|
||||
TokenType: tokenType,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.revokeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RevokeTokenFunc != nil {
|
||||
return m.RevokeTokenFunc(token, tokenType)
|
||||
}
|
||||
return m.RevokeErr
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertExchangeCalled verifies ExchangeCodeForToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
|
||||
}
|
||||
|
||||
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
for _, call := range m.ExchangeCalls {
|
||||
if call.GrantType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
|
||||
}
|
||||
|
||||
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
|
||||
}
|
||||
|
||||
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
|
||||
}
|
||||
|
||||
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return len(m.ExchangeCalls)
|
||||
}
|
||||
|
||||
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return len(m.RefreshCalls)
|
||||
}
|
||||
|
||||
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return len(m.RevokeCalls)
|
||||
}
|
||||
|
||||
// LastExchangeCall returns the most recent ExchangeCodeForToken call
|
||||
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
if len(m.ExchangeCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenExchanger) Reset() {
|
||||
m.mu.Lock()
|
||||
m.ExchangeResponse = nil
|
||||
m.ExchangeErr = nil
|
||||
m.RefreshResponse = nil
|
||||
m.RefreshErr = nil
|
||||
m.RevokeErr = nil
|
||||
m.ExchangeCodeFunc = nil
|
||||
m.RefreshTokenFunc = nil
|
||||
m.RevokeTokenFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = nil
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = nil
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = nil
|
||||
m.revokeCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
|
||||
type EnhancedMockCacheInterface struct {
|
||||
data map[string]cacheEntry
|
||||
GetCalls []CacheGetCall
|
||||
SetCalls []CacheSetCall
|
||||
DeleteCalls []string
|
||||
maxSize int
|
||||
mu sync.RWMutex
|
||||
getCalls sync.Mutex
|
||||
setCalls sync.Mutex
|
||||
deleteCalls sync.Mutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// CacheGetCall records parameters from a Get call
|
||||
type CacheGetCall struct {
|
||||
Timestamp time.Time
|
||||
Key string
|
||||
Found bool
|
||||
}
|
||||
|
||||
// CacheSetCall records parameters from a Set call
|
||||
type CacheSetCall struct {
|
||||
Timestamp time.Time
|
||||
Value any
|
||||
Key string
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// NewEnhancedMockCache creates a new enhanced cache mock
|
||||
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
|
||||
return &EnhancedMockCacheInterface{
|
||||
data: make(map[string]cacheEntry),
|
||||
maxSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = append(m.SetCalls, CacheSetCall{
|
||||
Key: key,
|
||||
Value: value,
|
||||
TTL: ttl,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
m.data[key] = cacheEntry{value: value, ttl: ttl}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
entry, found := m.data[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = append(m.GetCalls, CacheGetCall{
|
||||
Key: key,
|
||||
Found: found,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getCalls.Unlock()
|
||||
|
||||
if found {
|
||||
return entry.value, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Delete(key string) {
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = append(m.DeleteCalls, key)
|
||||
m.deleteCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.data, key)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
|
||||
m.mu.Lock()
|
||||
m.maxSize = size
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Size() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.data)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Clear() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Cleanup() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Close() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return map[string]any{
|
||||
"size": len(m.data),
|
||||
"max_size": m.maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetCalled verifies Get was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
for _, call := range m.GetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Get was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertSetCalled verifies Set was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
for _, call := range m.SetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Set was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertDeleteCalled verifies Delete was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
|
||||
m.deleteCalls.Lock()
|
||||
defer m.deleteCalls.Unlock()
|
||||
for _, k := range m.DeleteCalls {
|
||||
if k == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Delete was not called with key: "+key)
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of Get calls
|
||||
func (m *EnhancedMockCacheInterface) GetCallCount() int {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
return len(m.GetCalls)
|
||||
}
|
||||
|
||||
// SetCallCount returns the number of Set calls
|
||||
func (m *EnhancedMockCacheInterface) SetCallCount() int {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
return len(m.SetCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockCacheInterface) Reset() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = nil
|
||||
m.getCalls.Unlock()
|
||||
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = nil
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = nil
|
||||
m.deleteCalls.Unlock()
|
||||
}
|
||||
+150
-38
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -123,8 +127,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
|
||||
}
|
||||
|
||||
if metrics["total_requests"].(int64) > 0 {
|
||||
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
|
||||
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
|
||||
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
|
||||
if totalReq > 0 {
|
||||
successRate := float64(totalSucc) / float64(totalReq)
|
||||
metrics["success_rate"] = successRate
|
||||
} else {
|
||||
metrics["success_rate"] = 1.0
|
||||
@@ -409,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -485,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -536,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
@@ -592,14 +642,10 @@ func (e *HTTPError) Error() string {
|
||||
// OIDCError represents OIDC-specific errors with context information.
|
||||
// It provides structured error reporting for authentication and authorization failures.
|
||||
type OIDCError struct {
|
||||
// Code identifies the specific error type
|
||||
Code string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Context contains additional error context (e.g., provider, session details)
|
||||
Cause error
|
||||
Context map[string]interface{}
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the OIDC error.
|
||||
@@ -619,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
|
||||
// SessionError represents session-related errors with context.
|
||||
// Used for session management, validation, and storage errors.
|
||||
type SessionError struct {
|
||||
// Operation describes what session operation failed
|
||||
Cause error
|
||||
Operation string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// SessionID identifies the session (if available)
|
||||
Message string
|
||||
SessionID string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
}
|
||||
|
||||
// Error returns the string representation of the session error.
|
||||
@@ -646,14 +688,10 @@ func (e *SessionError) Unwrap() error {
|
||||
// TokenError represents token-related errors with validation context.
|
||||
// Used for JWT validation, token refresh, and token format errors.
|
||||
type TokenError struct {
|
||||
// TokenType identifies the type of token (id_token, access_token, refresh_token)
|
||||
Cause error
|
||||
TokenType string
|
||||
// Reason describes why the token is invalid
|
||||
Reason string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Reason string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the token error.
|
||||
@@ -715,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
|
||||
// It provides fallback mechanisms when primary services are unavailable and monitors
|
||||
// service health to automatically recover when services become available again.
|
||||
type GracefulDegradation struct {
|
||||
// BaseRecoveryMechanism provides common functionality
|
||||
*BaseRecoveryMechanism
|
||||
// fallbacks stores service-specific fallback implementations
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
// healthChecks stores service health check functions
|
||||
healthChecks map[string]func() bool
|
||||
// degradedServices tracks which services are currently degraded
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
// config contains graceful degradation configuration
|
||||
config GracefulDegradationConfig
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// healthCheckTask manages background health checking
|
||||
healthCheckTask *BackgroundTask
|
||||
// stopChan signals shutdown
|
||||
stopChan chan struct{}
|
||||
// shutdownOnce ensures shutdown happens only once
|
||||
shutdownOnce sync.Once
|
||||
healthCheckTask *BackgroundTask
|
||||
stopChan chan struct{}
|
||||
config GracefulDegradationConfig
|
||||
mutex sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
|
||||
@@ -1085,3 +1114,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDefaultCircuitBreakerConfig tests the default configuration function
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
// Test default values
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
|
||||
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Check expected metric fields
|
||||
expectedFields := []string{
|
||||
"total_requests",
|
||||
"total_failures",
|
||||
"total_successes",
|
||||
"uptime_seconds",
|
||||
"name",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if _, exists := metrics[field]; !exists {
|
||||
t.Errorf("Expected metric field %s to exist", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordRequest tests request recording
|
||||
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some requests
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalRequests := metrics["total_requests"].(int64)
|
||||
|
||||
if totalRequests != 3 {
|
||||
t.Errorf("Expected 3 total requests, got %d", totalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
|
||||
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some successes
|
||||
base.RecordSuccess()
|
||||
base.RecordSuccess()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalSuccesses := metrics["total_successes"].(int64)
|
||||
|
||||
if totalSuccesses != 2 {
|
||||
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
|
||||
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some failures
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalFailures := metrics["total_failures"].(int64)
|
||||
|
||||
if totalFailures != 3 {
|
||||
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogInfo tests info logging
|
||||
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogInfo("test message")
|
||||
base.LogInfo("test message with args: %s %d", "arg1", 42)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogInfo("test message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogError tests error logging
|
||||
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogError("error message")
|
||||
base.LogError("error message with args: %s %d", "error", 500)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogError("error message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogDebug tests debug logging
|
||||
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogDebug("debug message")
|
||||
base.LogDebug("debug message with args: %s %d", "debug", 123)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogDebug("debug message") // Should not panic
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetState tests getting circuit breaker state
|
||||
func TestCircuitBreaker_GetState(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initial state should be closed
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be closed, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests resetting circuit breaker
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Reset should not panic
|
||||
cb.Reset()
|
||||
|
||||
// State should be closed after reset
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be closed after reset, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsAvailable tests availability check
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initially should be available
|
||||
available := cb.IsAvailable()
|
||||
if !available {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Should include base metrics
|
||||
if _, exists := metrics["total_requests"]; !exists {
|
||||
t.Error("Expected total_requests in metrics")
|
||||
}
|
||||
|
||||
// Should include circuit breaker specific metrics
|
||||
if _, exists := metrics["state"]; !exists {
|
||||
t.Error("Expected state in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
// Retry mechanism tests removed due to complex dependencies
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package traefikoidc
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,486 @@
|
||||
# ============================================================================
|
||||
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
|
||||
# ============================================================================
|
||||
#
|
||||
# This example shows a complete, production-ready configuration for using
|
||||
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
|
||||
#
|
||||
|
||||
# ============================================================================
|
||||
# Part 1: Traefik Static Configuration (traefik.yml)
|
||||
# ============================================================================
|
||||
# This file configures Traefik itself and enables the plugin.
|
||||
# Place this in /etc/traefik/traefik.yml or mount it in your container.
|
||||
|
||||
---
|
||||
# Static Configuration
|
||||
api:
|
||||
dashboard: true
|
||||
insecure: false # Set to true only for local development
|
||||
|
||||
entryPoints:
|
||||
web:
|
||||
address: ":80"
|
||||
http:
|
||||
redirections:
|
||||
entryPoint:
|
||||
to: websecure
|
||||
scheme: https
|
||||
|
||||
websecure:
|
||||
address: ":443"
|
||||
http:
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
certificatesResolvers:
|
||||
letsencrypt:
|
||||
acme:
|
||||
email: admin@example.com
|
||||
storage: /letsencrypt/acme.json
|
||||
httpChallenge:
|
||||
entryPoint: web
|
||||
|
||||
providers:
|
||||
file:
|
||||
filename: /etc/traefik/dynamic.yml
|
||||
watch: true
|
||||
|
||||
# Enable the TraefikOIDC plugin
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
log:
|
||||
level: INFO
|
||||
format: json
|
||||
|
||||
accessLog:
|
||||
format: json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
|
||||
# ============================================================================
|
||||
# This file defines your routes, services, and middleware.
|
||||
# Place this in /etc/traefik/dynamic.yml
|
||||
|
||||
---
|
||||
http:
|
||||
# -------------------------------------------------------------------------
|
||||
# Middleware Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
middlewares:
|
||||
# Example 1: Minimal Redis Configuration
|
||||
# Perfect for getting started quickly
|
||||
oidc-minimal:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-application-client-id"
|
||||
clientSecret: "your-client-secret-from-provider"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
|
||||
|
||||
# Minimal Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
|
||||
# Example 2: Production Redis Configuration
|
||||
# Recommended for production deployments with multiple Traefik replicas
|
||||
oidc-production:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Provider Configuration
|
||||
clientID: "prod-client-id"
|
||||
clientSecret: "prod-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
# Security Settings
|
||||
forceHTTPS: true
|
||||
strictAudienceValidation: true
|
||||
|
||||
# Redis Configuration for Multi-Replica Deployment
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-master.redis-namespace.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:prod:"
|
||||
|
||||
# Cache Strategy
|
||||
cacheMode: "hybrid" # Fast local cache + shared Redis
|
||||
|
||||
# Connection Pooling
|
||||
poolSize: 20
|
||||
connectTimeout: 5
|
||||
readTimeout: 3
|
||||
writeTimeout: 3
|
||||
|
||||
# Resilience Features
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS (for production security)
|
||||
oidc-secure:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "secure-client-id"
|
||||
clientSecret: "secure-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Verify certificates in production
|
||||
cacheMode: "redis"
|
||||
|
||||
# Example 4: Hybrid Mode (Best Performance + Consistency)
|
||||
# Local cache for hot data, Redis for consistency across replicas
|
||||
oidc-hybrid:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "app-client-id"
|
||||
clientSecret: "app-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
cacheMode: "hybrid"
|
||||
|
||||
# Hybrid mode L1 cache settings
|
||||
hybridL1Size: 1000 # Number of items in local cache
|
||||
hybridL1MemoryMB: 20 # MB of memory for local cache
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Router Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
routers:
|
||||
# Protected application using OIDC authentication
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-production # Use the OIDC middleware
|
||||
service: my-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# Another app with minimal OIDC config
|
||||
simple-app:
|
||||
rule: "Host(`simple.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-minimal
|
||||
service: simple-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Service Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://my-app:8080"
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
|
||||
simple-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://simple-app:3000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 3: Docker Compose Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# docker-compose.yml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Redis service for shared caching
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Traefik with TraefikOIDC plugin
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
command:
|
||||
- "--api.dashboard=true"
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--providers.file.filename=/etc/traefik/dynamic.yml"
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.8.0"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # Dashboard
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
|
||||
- ./letsencrypt:/letsencrypt
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Your application
|
||||
my-app:
|
||||
image: my-app:latest
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
- "traefik.http.routers.my-app.entrypoints=websecure"
|
||||
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
|
||||
|
||||
# OIDC Middleware Configuration with Redis (using labels)
|
||||
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
|
||||
|
||||
# Redis configuration
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
networks:
|
||||
- traefik-network
|
||||
deploy:
|
||||
replicas: 3 # Multiple replicas sharing Redis cache
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
|
||||
networks:
|
||||
traefik-network:
|
||||
driver: bridge
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 4: Kubernetes Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# kubernetes-example.yaml
|
||||
|
||||
# Redis Deployment
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: redis
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: redis
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: redis:7-alpine
|
||||
args:
|
||||
- redis-server
|
||||
- --requirepass
|
||||
- $(REDIS_PASSWORD)
|
||||
- --maxmemory
|
||||
- 512mb
|
||||
- --maxmemory-policy
|
||||
- allkeys-lru
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: redis-secret
|
||||
key: password
|
||||
ports:
|
||||
- containerPort: 6379
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "100m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
---
|
||||
# Redis Service
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
selector:
|
||||
app: redis
|
||||
ports:
|
||||
- port: 6379
|
||||
targetPort: 6379
|
||||
---
|
||||
# Redis Secret
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: redis-secret
|
||||
namespace: traefik
|
||||
type: Opaque
|
||||
stringData:
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
---
|
||||
# OIDC Middleware with Redis
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Configuration
|
||||
clientID: "kubernetes-client-id"
|
||||
clientSecret: "kubernetes-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
|
||||
|
||||
# Redis Configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.traefik.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:k8s:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
enableHealthCheck: true
|
||||
---
|
||||
# IngressRoute using the middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: IngressRoute
|
||||
metadata:
|
||||
name: my-app
|
||||
namespace: default
|
||||
spec:
|
||||
entryPoints:
|
||||
- websecure
|
||||
routes:
|
||||
- match: Host(`app.example.com`)
|
||||
kind: Rule
|
||||
middlewares:
|
||||
- name: oidc-auth
|
||||
namespace: traefik
|
||||
services:
|
||||
- name: my-app
|
||||
port: 80
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 5: Environment Variables (Optional Fallback)
|
||||
# ============================================================================
|
||||
|
||||
# If you prefer environment variables as fallback (not recommended for production),
|
||||
# you can set these. NOTE: Plugin configuration takes precedence!
|
||||
|
||||
# Docker Compose env file (.env)
|
||||
---
|
||||
# OIDC Configuration
|
||||
OIDC_CLIENT_ID=your-client-id
|
||||
OIDC_CLIENT_SECRET=your-client-secret
|
||||
OIDC_PROVIDER_URL=https://auth.example.com
|
||||
|
||||
# Redis Configuration (fallback)
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=yourredispassword
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=20
|
||||
REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
REDIS_ENABLE_HEALTH_CHECK=true
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Cheat Sheet
|
||||
# ============================================================================
|
||||
|
||||
# Minimal Setup (Quick Start):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis:6379"
|
||||
|
||||
# Production Setup (Recommended):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis-master:6379"
|
||||
# password: "strong-password"
|
||||
# cacheMode: "hybrid"
|
||||
# enableCircuitBreaker: true
|
||||
# enableHealthCheck: true
|
||||
|
||||
# High Security Setup:
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis.example.com:6380"
|
||||
# password: "strong-password"
|
||||
# enableTLS: true
|
||||
# tlsSkipVerify: false
|
||||
# cacheMode: "redis"
|
||||
|
||||
# Cache Modes:
|
||||
# - "memory": Local cache only (default, no Redis needed)
|
||||
# - "redis": Redis only (consistent, shared across replicas)
|
||||
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
|
||||
@@ -0,0 +1,149 @@
|
||||
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
|
||||
# This example shows how to configure Redis through Traefik's dynamic configuration
|
||||
|
||||
# Static configuration (traefik.yml)
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
# Dynamic configuration (dynamic.yml or labels)
|
||||
http:
|
||||
middlewares:
|
||||
# Example 1: Basic Redis configuration
|
||||
oidc-redis-basic:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
# password: "your-redis-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
|
||||
# Example 2: Redis with resilience features
|
||||
oidc-redis-resilient:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with full resilience configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
|
||||
db: 1
|
||||
keyPrefix: "myapp:"
|
||||
poolSize: 20
|
||||
connectTimeout: 10
|
||||
readTimeout: 5
|
||||
writeTimeout: 5
|
||||
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
|
||||
# Circuit breaker settings
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
# Health check settings
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS
|
||||
oidc-redis-tls:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with TLS configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Set to true only for testing
|
||||
cacheMode: "redis"
|
||||
|
||||
routers:
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
middlewares:
|
||||
- oidc-redis-basic
|
||||
service: my-app-service
|
||||
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://localhost:8080"
|
||||
|
||||
# Docker Compose labels example
|
||||
# version: '3.8'
|
||||
# services:
|
||||
# traefik:
|
||||
# image: traefik:v3.0
|
||||
# # ... other config ...
|
||||
#
|
||||
# my-app:
|
||||
# image: my-app:latest
|
||||
# labels:
|
||||
# - "traefik.enable=true"
|
||||
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
# - "traefik.http.routers.my-app.middlewares=my-oidc"
|
||||
# # OIDC middleware configuration with Redis
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# # Redis configuration via labels
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
|
||||
#
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# command: redis-server --requirepass redis-password
|
||||
# # ... other config ...
|
||||
|
||||
# Environment variable fallback (optional)
|
||||
# If Redis configuration is not provided in Traefik config, these environment variables
|
||||
# can be used as a fallback (but Traefik config takes precedence):
|
||||
#
|
||||
# REDIS_ENABLED=true
|
||||
# REDIS_ADDRESS=redis:6379
|
||||
# REDIS_PASSWORD=secret
|
||||
# REDIS_DB=0
|
||||
# REDIS_KEY_PREFIX=traefikoidc:
|
||||
# REDIS_CACHE_MODE=redis
|
||||
# REDIS_POOL_SIZE=10
|
||||
# REDIS_CONNECT_TIMEOUT=5
|
||||
# REDIS_READ_TIMEOUT=3
|
||||
# REDIS_WRITE_TIMEOUT=3
|
||||
# REDIS_ENABLE_TLS=false
|
||||
# REDIS_TLS_SKIP_VERIFY=false
|
||||
# REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
|
||||
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
|
||||
# REDIS_ENABLE_HEALTH_CHECK=true
|
||||
# REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
@@ -1,797 +0,0 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Mock types for testing
|
||||
type TemplatedHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
// TestTemplateHeaderFeatures consolidates all template header-related tests
|
||||
func TestTemplateHeaderFeatures(t *testing.T) {
|
||||
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
|
||||
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
|
||||
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
|
||||
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
|
||||
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
|
||||
t.Run("Template_Execution_Context", testTemplateExecutionContext)
|
||||
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
|
||||
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
|
||||
t.Run("Missing_Field_Handling", testMissingFieldHandling)
|
||||
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
|
||||
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
|
||||
}
|
||||
|
||||
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
|
||||
// receive wrong data types during execution - reproduces GitHub issue #55
|
||||
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
templateData interface{}
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "correct map data",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "valid-token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "boolean as root context - reproduces issue #55",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: true,
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type bool",
|
||||
},
|
||||
{
|
||||
name: "string as root context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: "just a string",
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type string",
|
||||
},
|
||||
{
|
||||
name: "nested claims access with correct data",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nested claims with wrong structure",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": "not a map",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field email in type",
|
||||
},
|
||||
{
|
||||
name: "complex nested structure",
|
||||
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "token123",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user-id",
|
||||
"groups": "admin,users",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.templateData)
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
if tc.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateParsingValidation ensures templates are parsed correctly
|
||||
func testTemplateParsingValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headerTemplates []TemplatedHeader
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid bearer token template",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple valid templates",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "template with conditional logic",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid template syntax",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Bad-Template", Value: "{{.AccessToken"},
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, header := range tc.headerTemplates {
|
||||
_, err := template.New(header.Name).Parse(header.Value)
|
||||
|
||||
if tc.shouldError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testMiddlewareHeaderTemplating simulates the actual middleware flow
|
||||
func testMiddlewareHeaderTemplating(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
accessToken string
|
||||
idToken string
|
||||
claims map[string]interface{}
|
||||
expectedValues map[string]string
|
||||
}{
|
||||
{
|
||||
name: "authorization header with access token",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
expectedValues: map[string]string{
|
||||
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple headers with claims",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
|
||||
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "token123",
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"groups": "admin,developers",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Email": "user@example.com",
|
||||
"X-User-Groups": "admin,developers",
|
||||
"X-Auth-Token": "token123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex template expressions",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
|
||||
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
|
||||
},
|
||||
accessToken: "access-token",
|
||||
idToken: "id-token",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-12345",
|
||||
"email": "john@example.com",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Info": "user-12345 (john@example.com)",
|
||||
"X-Auth-Header": "Bearer access-token | ID: id-token",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Parse all templates
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
// Create template data
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": tc.accessToken,
|
||||
"IDToken": tc.idToken,
|
||||
"Claims": tc.claims,
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Execute templates and set headers
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(headerName, buf.String())
|
||||
}
|
||||
|
||||
// Verify all expected headers are set correctly
|
||||
for headerName, expectedValue := range tc.expectedValues {
|
||||
actualValue := req.Header.Get(headerName)
|
||||
assert.Equal(t, expectedValue, actualValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testJSONConfigParsing tests that JSON configuration is properly parsed
|
||||
func testJSONConfigParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jsonConfig string
|
||||
expectedError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid JSON configuration",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": "Bearer {{.AccessToken}}"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: false,
|
||||
description: "Properly formatted JSON with string values",
|
||||
},
|
||||
{
|
||||
name: "JSON with boolean value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": true
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Boolean value instead of string template",
|
||||
},
|
||||
{
|
||||
name: "JSON with number value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": 123
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Number value instead of string template",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var config struct {
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
|
||||
|
||||
if tc.expectedError {
|
||||
require.Error(t, err, tc.description)
|
||||
} else {
|
||||
require.NoError(t, err, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateDoubleProcessing tests if template strings are being double-processed
|
||||
func testTemplateDoubleProcessing(t *testing.T) {
|
||||
// Simulate how Traefik passes config to the plugin
|
||||
config := &MockConfig{
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Verify that template strings are still raw (not processed)
|
||||
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
|
||||
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
|
||||
|
||||
// Simulate template parsing during initialization
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" || val == "<no value>" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
"get": func(m interface{}, key string) interface{} {
|
||||
if mapVal, ok := m.(map[string]interface{}); ok {
|
||||
if val, exists := mapVal[key]; exists {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
},
|
||||
}
|
||||
|
||||
for _, header := range config.Headers {
|
||||
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
|
||||
parsedTmpl, err := tmpl.Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = parsedTmpl
|
||||
}
|
||||
|
||||
// Test execution with actual claims
|
||||
claims := map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
// Note: internal_role is missing
|
||||
}
|
||||
|
||||
templateData := map[string]interface{}{
|
||||
"Claims": claims,
|
||||
}
|
||||
|
||||
// Execute templates
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := buf.String()
|
||||
if headerName == "X-User-Email" {
|
||||
assert.Equal(t, "user@example.com", result)
|
||||
} else if headerName == "X-User-Role" {
|
||||
// With missingkey=zero, missing fields return "<no value>"
|
||||
assert.Equal(t, "<no value>", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateExecutionContext tests the specific template data context
|
||||
func testTemplateExecutionContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token",
|
||||
"IDToken": "id-token",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Custom non-standard claims",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"app_metadata": map[string]interface{}{
|
||||
"organization": map[string]interface{}{
|
||||
"name": "acme-corp",
|
||||
},
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: acme-corp, X-Team: platform",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedValue, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
|
||||
func testTemplateIntegrationWithPlugin(t *testing.T) {
|
||||
// Test template integration using mock plugin components
|
||||
|
||||
// Set up test OIDC server
|
||||
var testServerURL string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": testServerURL,
|
||||
"authorization_endpoint": testServerURL + "/auth",
|
||||
"token_endpoint": testServerURL + "/token",
|
||||
"jwks_uri": testServerURL + "/jwks",
|
||||
"userinfo_endpoint": testServerURL + "/userinfo",
|
||||
})
|
||||
case "/jwks":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []interface{}{},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer testServer.Close()
|
||||
testServerURL = testServer.URL
|
||||
|
||||
// Create config with templates that reference potentially missing fields
|
||||
config := &MockConfig{
|
||||
ProviderURL: testServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-characters",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize plugin would be done here
|
||||
ctx := context.Background()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Test would create plugin handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = config
|
||||
}
|
||||
|
||||
// testTemplateSyntaxValidation tests that template syntax is properly validated
|
||||
func testTemplateSyntaxValidation(t *testing.T) {
|
||||
validTemplates := []string{
|
||||
"{{.Claims.email}}",
|
||||
"{{.Claims.internal_role}}",
|
||||
"{{.AccessToken}}",
|
||||
"{{.IdToken}}",
|
||||
"{{.RefreshToken}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range validTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
|
||||
}
|
||||
|
||||
// Test invalid templates
|
||||
invalidTemplates := []struct {
|
||||
template string
|
||||
reason string
|
||||
}{
|
||||
{"{{call .SomeFunc}}", "function calls not allowed"},
|
||||
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
|
||||
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
|
||||
{"{{index .Array 0}}", "index access blocked"},
|
||||
{"{{printf \"%s\" .Data}}", "printf blocked"},
|
||||
}
|
||||
|
||||
for _, tc := range invalidTemplates {
|
||||
err := validateTemplateSecure(tc.template)
|
||||
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
|
||||
}
|
||||
|
||||
// Test safe custom functions
|
||||
safeTemplates := []string{
|
||||
"{{get .Claims \"internal_role\"}}",
|
||||
"{{default \"guest\" .Claims.role}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range safeTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Mock validation function for template security
|
||||
func validateTemplateSecure(templateStr string) error {
|
||||
// List of potentially dangerous template actions
|
||||
dangerousFunctions := []string{
|
||||
"call", "range", "with", "index", "printf", "println", "print",
|
||||
"js", "html", "urlquery", "base64", "exec",
|
||||
}
|
||||
|
||||
for _, dangerous := range dangerousFunctions {
|
||||
if strings.Contains(templateStr, dangerous) {
|
||||
return fmt.Errorf("dangerous template function detected: %s", dangerous)
|
||||
}
|
||||
}
|
||||
|
||||
// Define safe custom functions
|
||||
funcMap := template.FuncMap{
|
||||
"get": func(data map[string]interface{}, key string) interface{} {
|
||||
return data[key]
|
||||
},
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
}
|
||||
|
||||
// Try to parse the template with custom functions to check for syntax errors
|
||||
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// testMissingFieldHandling tests handling of missing fields in templates
|
||||
func testMissingFieldHandling(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "missing claim field",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing nested field",
|
||||
templateText: "{{.Claims.user.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"user": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing entire path",
|
||||
templateText: "{{.Missing.Path.Field}}",
|
||||
data: map[string]interface{}{},
|
||||
expected: "<no value>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testComplexTemplateExpressions tests complex template expressions
|
||||
func testComplexTemplateExpressions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "conditional template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
expected: "Admin User",
|
||||
},
|
||||
{
|
||||
name: "multiple claims concatenation",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
},
|
||||
},
|
||||
expected: "John Doe <john.doe@example.com>",
|
||||
},
|
||||
{
|
||||
name: "array access",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
expected: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
|
||||
func testTraefikConfigurationParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *MockConfig
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid configuration with templated headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Standard configuration should work",
|
||||
},
|
||||
{
|
||||
name: "configuration with multiple headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Multiple headers should work",
|
||||
},
|
||||
{
|
||||
name: "empty headers configuration",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Empty headers should not cause issues",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a simple next handler
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Try to create the middleware would be done here
|
||||
ctx := context.Background()
|
||||
|
||||
// Test would create middleware handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = tc.config
|
||||
|
||||
// For now, we just validate the configuration is well-formed
|
||||
if !tc.expectError {
|
||||
require.NotNil(t, tc.config, tc.description)
|
||||
require.NotEmpty(t, tc.config.ClientID, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,15 +3,21 @@ module github.com/lukaszraczylo/traefikoidc
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.13.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -10,10 +20,16 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -10,16 +10,16 @@ import (
|
||||
type GoroutineManager struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
goroutines map[string]*managedGoroutine
|
||||
logger *Logger
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type managedGoroutine struct {
|
||||
name string
|
||||
cancel context.CancelFunc
|
||||
startTime time.Time
|
||||
cancel context.CancelFunc
|
||||
name string
|
||||
running bool
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.logger.Debugf("Periodic task %s cancelled", name)
|
||||
m.logger.Debugf("Periodic task %s canceled", name)
|
||||
return
|
||||
case <-ticker.C:
|
||||
task()
|
||||
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
|
||||
|
||||
// GoroutineStatus represents the status of a managed goroutine
|
||||
type GoroutineStatus struct {
|
||||
Name string
|
||||
Running bool
|
||||
StartTime time.Time
|
||||
Name string
|
||||
Runtime time.Duration
|
||||
Running bool
|
||||
}
|
||||
|
||||
// ErrShutdownTimeout is returned when shutdown times out
|
||||
|
||||
@@ -0,0 +1,625 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test GoroutineManager Creation
|
||||
|
||||
func TestNewGoroutineManager(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
if gm == nil {
|
||||
t.Fatal("Expected non-nil goroutine manager")
|
||||
}
|
||||
|
||||
if gm.ctx == nil {
|
||||
t.Error("Expected context to be initialized")
|
||||
}
|
||||
|
||||
if gm.cancel == nil {
|
||||
t.Error("Expected cancel function to be initialized")
|
||||
}
|
||||
|
||||
if gm.goroutines == nil {
|
||||
t.Error("Expected goroutines map to be initialized")
|
||||
}
|
||||
|
||||
if gm.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Starting Goroutines
|
||||
|
||||
func TestStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
})
|
||||
|
||||
// Give goroutine time to execute
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if len(status) != 1 {
|
||||
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if _, exists := status["test-goroutine"]; !exists {
|
||||
t.Error("Expected goroutine 'test-goroutine' in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineDuplicate(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start a long-running goroutine
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
// Give first goroutine time to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to start another with same name (should be skipped)
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Should only have executed once
|
||||
if counter.Load() != 1 {
|
||||
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineContextCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
started := atomic.Bool{}
|
||||
canceled := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
|
||||
started.Store(true)
|
||||
<-ctx.Done()
|
||||
canceled.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !started.Load() {
|
||||
t.Error("Expected goroutine to start")
|
||||
}
|
||||
|
||||
// Stop the goroutine
|
||||
gm.StopGoroutine("cancel-test")
|
||||
|
||||
// Wait for cancellation
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !canceled.Load() {
|
||||
t.Error("Expected goroutine to be canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineWithPanic(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("panic-test", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// Give goroutine time to panic and recover
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute before panic")
|
||||
}
|
||||
|
||||
// Check that goroutine is marked as not running after panic
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running after panic")
|
||||
}
|
||||
}
|
||||
|
||||
// Manager should still be functional
|
||||
counter := atomic.Int32{}
|
||||
gm.StartGoroutine("after-panic", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 1 {
|
||||
t.Error("Expected manager to still be functional after panic recovery")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Periodic Tasks
|
||||
|
||||
func TestStartPeriodicTask(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for multiple executions
|
||||
time.Sleep(160 * time.Millisecond)
|
||||
|
||||
// Should have executed at least 2-3 times
|
||||
count := counter.Load()
|
||||
if count < 2 {
|
||||
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartPeriodicTaskCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for some executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
// Stop the task
|
||||
gm.StopGoroutine("cancel-periodic")
|
||||
|
||||
countBeforeStop := counter.Load()
|
||||
|
||||
// Wait and verify no more executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
countAfterStop := counter.Load()
|
||||
|
||||
// Allow 1 additional execution (could be in progress when stopped)
|
||||
if countAfterStop > countBeforeStop+1 {
|
||||
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
|
||||
countBeforeStop, countAfterStop)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Stopping Goroutines
|
||||
|
||||
func TestStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
stopped := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("stop-test", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
stopped.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
gm.StopGoroutine("stop-test")
|
||||
|
||||
// Wait for goroutine to stop
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !stopped.Load() {
|
||||
t.Error("Expected goroutine to be stopped")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["stop-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopGoroutineNonExistent(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Should not panic or error when stopping non-existent goroutine
|
||||
gm.StopGoroutine("non-existent")
|
||||
}
|
||||
|
||||
func TestStopGoroutineAlreadyStopped(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
|
||||
// Exit immediately
|
||||
})
|
||||
|
||||
// Wait for goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to stop already-stopped goroutine (should be safe)
|
||||
gm.StopGoroutine("already-stopped")
|
||||
}
|
||||
|
||||
// Test Shutdown
|
||||
|
||||
func TestShutdownGraceful(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start multiple goroutines
|
||||
for i := 0; i < 5; i++ {
|
||||
name := "goroutine-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
counter.Add(-1)
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 5 {
|
||||
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
|
||||
}
|
||||
|
||||
// Shutdown with generous timeout
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected graceful shutdown, got error: %v", err)
|
||||
}
|
||||
|
||||
if counter.Load() != 0 {
|
||||
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownWithTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
|
||||
gm.StartGoroutine("stubborn", func(ctx context.Context) {
|
||||
// Simulate a goroutine that takes too long to stop
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Shutdown with very short timeout
|
||||
err := gm.Shutdown(10 * time.Millisecond)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error")
|
||||
}
|
||||
|
||||
if err != ErrShutdownTimeout {
|
||||
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown with no goroutines should succeed immediately
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for empty shutdown, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Status
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start multiple goroutines with different states
|
||||
gm.StartGoroutine("running", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
gm.StartGoroutine("quick", func(ctx context.Context) {
|
||||
// Exits immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if len(status) != 2 {
|
||||
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if runningStatus, exists := status["running"]; exists {
|
||||
if !runningStatus.Running {
|
||||
t.Error("Expected 'running' goroutine to be marked as running")
|
||||
}
|
||||
|
||||
if runningStatus.Name != "running" {
|
||||
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
|
||||
}
|
||||
|
||||
if runningStatus.StartTime.IsZero() {
|
||||
t.Error("Expected non-zero start time")
|
||||
}
|
||||
|
||||
if runningStatus.Runtime <= 0 {
|
||||
t.Error("Expected positive runtime")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'running' goroutine in status")
|
||||
}
|
||||
|
||||
if quickStatus, exists := status["quick"]; exists {
|
||||
if quickStatus.Running {
|
||||
t.Error("Expected 'quick' goroutine to be marked as not running")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'quick' goroutine in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatusEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if status == nil {
|
||||
t.Fatal("Expected non-nil status map")
|
||||
}
|
||||
|
||||
if len(status) != 0 {
|
||||
t.Errorf("Expected empty status, got %d entries", len(status))
|
||||
}
|
||||
}
|
||||
|
||||
// Test Concurrent Operations
|
||||
|
||||
func TestConcurrentStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(2 * time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
const numGoroutines = 50
|
||||
|
||||
// Start many goroutines concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
counter.Add(-1)
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Verify goroutines are tracked
|
||||
status := gm.GetStatus()
|
||||
if len(status) < numGoroutines/2 {
|
||||
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
const numGoroutines = 20
|
||||
|
||||
// Start goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
name := "stop-concurrent-" + string(rune('0'+i%10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop all concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "stop-concurrent-" + string(rune('0'+id%10))
|
||||
gm.StopGoroutine(name)
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all stopped
|
||||
status := gm.GetStatus()
|
||||
for _, s := range status {
|
||||
if s.Running {
|
||||
t.Errorf("Expected goroutine %s to be stopped", s.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start some goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
name := "status-test-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrently read status many times (should not race)
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 20; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = gm.GetStatus()
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all concurrent reads
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Cases
|
||||
|
||||
func TestShutdownTimeoutError(t *testing.T) {
|
||||
err := ErrShutdownTimeout
|
||||
|
||||
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
|
||||
t.Errorf("Unexpected error message: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Edge Cases
|
||||
|
||||
func TestStartGoroutineAfterShutdown(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown immediately
|
||||
_ = gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
// Try to start goroutine after shutdown
|
||||
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Goroutine should have started but context already canceled
|
||||
// It may or may not execute depending on timing, but shouldn't panic
|
||||
status := gm.GetStatus()
|
||||
if _, exists := status["after-shutdown"]; exists {
|
||||
// If it's in status, it was tracked (acceptable)
|
||||
t.Log("Goroutine was tracked even after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleShutdowns(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// First shutdown
|
||||
err1 := gm.Shutdown(time.Second)
|
||||
if err1 != nil {
|
||||
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
|
||||
}
|
||||
|
||||
// Second shutdown (should not panic or error)
|
||||
err2 := gm.Shutdown(time.Second)
|
||||
if err2 != nil {
|
||||
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoroutineWithImmediateReturn(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("immediate", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
// Return immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["immediate"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected immediately-returning goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeriodicTaskPanicRecovery(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
if counter.Load() == 2 {
|
||||
panic("periodic panic")
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for panic to occur
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// After panic, the goroutine should have stopped
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-periodic"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected panicked periodic task to stop")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,764 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// OAuth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestOAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
|
||||
// Test authorization request handling logic
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestURL string
|
||||
expectedStatus int
|
||||
checkLocation bool
|
||||
}{
|
||||
{
|
||||
name: "Valid authorization request",
|
||||
requestURL: "/auth/login",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
{
|
||||
name: "With return URL",
|
||||
requestURL: "/auth/login?return=/dashboard",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the test case structure
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.requestURL == "" {
|
||||
t.Error("Request URL should not be empty")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
// In a real implementation, this would test the actual handler
|
||||
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Authorization request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleCallbackRequest", func(t *testing.T) {
|
||||
// Test callback request handling with existing mocks
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid callback with code",
|
||||
queryParams: "code=test-code&state=test-state",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Callback with error",
|
||||
queryParams: "error=access_denied&error_description=User denied access",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing code",
|
||||
queryParams: "state=test-state",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing state",
|
||||
queryParams: "code=test-code",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the callback scenarios
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.queryParams == "" && !test.expectError {
|
||||
t.Error("Query params should not be empty for successful cases")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
|
||||
// Test session manager functionality
|
||||
if sessionManager != nil {
|
||||
t.Logf("Session manager available for test %s", test.name)
|
||||
}
|
||||
|
||||
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Callback request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleLogout", func(t *testing.T) {
|
||||
// Test logout functionality with mock implementations
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
// Test session clearing
|
||||
mockReq := &http.Request{}
|
||||
session, err := sessionManager.GetSession(mockReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set up authenticated session
|
||||
err = session.SetAuthenticated(true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set authentication: %v", err)
|
||||
}
|
||||
session.SetIDToken("test-token")
|
||||
|
||||
// Verify session is authenticated
|
||||
if !session.GetAuthenticated() {
|
||||
t.Error("Session should be authenticated before logout")
|
||||
}
|
||||
|
||||
// Test logout by clearing session
|
||||
// session.Clear() // Method not implemented in SessionData
|
||||
// Additional logout verification would go here
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Logout test completed")
|
||||
t.Log("Logout test completed successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Auth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthentication", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
/*
|
||||
handler := &MockAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*MockSession)
|
||||
expectedStatus int
|
||||
expectNext bool
|
||||
}{
|
||||
{
|
||||
name: "Authenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("valid-token")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectNext: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(false)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("expired-token")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandleRefreshToken", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
mockResponse *MockTokenResponse
|
||||
mockError error
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "Successful refresh",
|
||||
refreshToken: "valid-refresh-token",
|
||||
mockResponse: &MockTokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
IDToken: "new-id-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
},
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Failed refresh",
|
||||
refreshToken: "invalid-refresh-token",
|
||||
mockError: errors.New("invalid_grant"),
|
||||
expectSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "Empty refresh token",
|
||||
refreshToken: "",
|
||||
expectSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestErrorHandler(t *testing.T) {
|
||||
t.Run("HandleHTTPErrors", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorCode int
|
||||
errorMessage string
|
||||
isAjax bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Authentication required",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "403 Forbidden",
|
||||
errorCode: http.StatusForbidden,
|
||||
errorMessage: "Access denied",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: "Access denied",
|
||||
},
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
errorCode: http.StatusInternalServerError,
|
||||
errorMessage: "Internal server error",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal server error",
|
||||
},
|
||||
{
|
||||
name: "Ajax 401",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Token expired",
|
||||
isAjax: true,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RecoverFromPanic", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
panicValue interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String panic",
|
||||
panicValue: "something went wrong",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Error panic",
|
||||
panicValue: errors.New("critical error"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Nil panic",
|
||||
panicValue: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Azure OAuth Callback Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAzureOAuthCallback(t *testing.T) {
|
||||
t.Run("AzureSpecificClaims", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
azureClaims := map[string]interface{}{
|
||||
"oid": "object-id",
|
||||
"tid": "tenant-id",
|
||||
"preferred_username": "user@example.com",
|
||||
"name": "Test User",
|
||||
"email": "user@example.com",
|
||||
"groups": []string{"group1", "group2"},
|
||||
}
|
||||
|
||||
// Test would go here when properly implemented
|
||||
_ = azureClaims
|
||||
})
|
||||
|
||||
t.Run("AzureTokenValidation", func(t *testing.T) {
|
||||
// Test with mock validator types
|
||||
/*
|
||||
validator := &MockAzureTokenValidator{
|
||||
tenantID: "test-tenant",
|
||||
clientID: "test-client",
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
claims map[string]interface{}
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Azure token",
|
||||
token: "valid-azure-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong tenant",
|
||||
token: "wrong-tenant-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "wrong-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "Wrong audience",
|
||||
token: "wrong-audience-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "wrong-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConcurrentHandlers(t *testing.T) {
|
||||
t.Run("ConcurrentCallbacks", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int32(0)
|
||||
errorCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = successCount
|
||||
_ = errorCount
|
||||
})
|
||||
|
||||
t.Run("ConcurrentLogouts", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
logoutCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = logoutCount
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Implementations
|
||||
// ============================================================================
|
||||
|
||||
type MockSessionManager struct {
|
||||
sessions map[string]*MockSession
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMockSessionManager() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
sessions: make(map[string]*MockSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessionID := "test-session"
|
||||
if session, exists := m.sessions[sessionID]; exists {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
session := &MockSession{
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
m.sessions[sessionID] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
values map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAuthenticated(auth bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["authenticated"] = auth
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAuthenticated() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
auth, ok := s.values["authenticated"].(bool)
|
||||
return ok && auth
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIDToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["id_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIDToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["id_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAccessToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["access_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAccessToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["access_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetRefreshToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["refresh_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetRefreshToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["refresh_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetState(state string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["state"] = state
|
||||
}
|
||||
|
||||
func (s *MockSession) GetState() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
state, _ := s.values["state"].(string)
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *MockSession) SetClaims(claims map[string]interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["claims"] = claims
|
||||
}
|
||||
|
||||
func (s *MockSession) GetClaims() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
claims, _ := s.values["claims"].(map[string]interface{})
|
||||
return claims
|
||||
}
|
||||
|
||||
// Additional SessionData interface methods to match real interface
|
||||
func (s *MockSession) GetCSRF() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
csrf, _ := s.values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) GetNonce() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
nonce, _ := s.values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) GetCodeVerifier() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
verifier, _ := s.values["code_verifier"].(string)
|
||||
return verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIncomingPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
path, _ := s.values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
func (s *MockSession) SetEmail(email string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["email"] = email
|
||||
}
|
||||
|
||||
func (s *MockSession) GetEmail() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
email, _ := s.values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCSRF(csrf string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["csrf"] = csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) SetNonce(nonce string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["nonce"] = nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCodeVerifier(verifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["code_verifier"] = verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIncomingPath(path string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["incoming_path"] = path
|
||||
}
|
||||
|
||||
func (s *MockSession) ResetRedirectCount() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["redirect_count"] = 0
|
||||
}
|
||||
|
||||
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values = make(map[string]interface{})
|
||||
}
|
||||
|
||||
func (s *MockSession) returnToPoolSafely() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
type MockTokenValidator struct {
|
||||
valid bool
|
||||
}
|
||||
|
||||
func (v *MockTokenValidator) Validate(token string) bool {
|
||||
if token == "expired-token" {
|
||||
return false
|
||||
}
|
||||
return v.valid
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Handler Type Definitions (for testing)
|
||||
// ============================================================================
|
||||
|
||||
// These mock handlers are simplified versions for testing purposes
|
||||
// They don't match the actual handler implementations
|
||||
|
||||
type MockAuthHandler struct{}
|
||||
|
||||
type MockErrorHandler struct{}
|
||||
|
||||
type MockAzureTokenValidator struct {
|
||||
tenantID string
|
||||
clientID string
|
||||
}
|
||||
|
||||
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
|
||||
// Validate tenant ID
|
||||
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate audience
|
||||
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Types and Mock Logger
|
||||
// ============================================================================
|
||||
|
||||
type MockLogger struct{}
|
||||
|
||||
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Error(msg string) {}
|
||||
|
||||
type MockTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
@@ -1,308 +0,0 @@
|
||||
// Package handlers provides HTTP request handlers for the OIDC middleware.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth callback requests
|
||||
type OAuthHandler struct {
|
||||
logger Logger
|
||||
sessionManager SessionManager
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
}
|
||||
|
||||
// SessionManager interface for session operations
|
||||
type SessionManager interface {
|
||||
GetSession(req *http.Request) (SessionData, error)
|
||||
}
|
||||
|
||||
// SessionData interface for session data operations
|
||||
type SessionData interface {
|
||||
GetCSRF() string
|
||||
GetNonce() string
|
||||
GetCodeVerifier() string
|
||||
GetIncomingPath() string
|
||||
GetAuthenticated() bool
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
GetIDToken() string
|
||||
GetEmail() string
|
||||
SetAuthenticated(bool) error
|
||||
SetEmail(string)
|
||||
SetIDToken(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetCSRF(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetIncomingPath(string)
|
||||
ResetRedirectCount()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
returnToPoolSafely()
|
||||
}
|
||||
|
||||
// TokenExchanger interface for token operations
|
||||
type TokenExchanger interface {
|
||||
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
}
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from token exchange
|
||||
type TokenResponse struct {
|
||||
IDToken string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleCallback handles OAuth callback requests
|
||||
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := h.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Session error during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Debug logging for cookie configuration
|
||||
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
|
||||
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
|
||||
|
||||
// Log all cookies in the request for debugging
|
||||
cookies := req.Cookies()
|
||||
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_") {
|
||||
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
|
||||
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error")
|
||||
}
|
||||
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
h.logger.Error("No state in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log the state parameter received
|
||||
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
||||
session.GetAuthenticated(), req.URL.String())
|
||||
|
||||
// Enhanced debugging for missing CSRF token
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
|
||||
}
|
||||
|
||||
// Log session state for debugging
|
||||
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
|
||||
session.GetAuthenticated(), session.GetAccessToken() != "")
|
||||
|
||||
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log successful CSRF token retrieval
|
||||
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
|
||||
|
||||
if state != csrfToken {
|
||||
h.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
h.logger.Error("No code in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to extract claims during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
h.logger.Error("Nonce claim missing in id_token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
h.logger.Error("Nonce not found in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
h.logger.Error("Nonce claim does not match session nonce during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.ResetRedirectCount()
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session after callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// URLHelper provides utility methods for URL operations
|
||||
type URLHelper struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewURLHelper creates a new URL helper
|
||||
func NewURLHelper(logger Logger) *URLHelper {
|
||||
return &URLHelper{logger: logger}
|
||||
}
|
||||
|
||||
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
|
||||
// It compares the request path against configured excluded URL prefixes.
|
||||
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
|
||||
for excludedURL := range excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DetermineScheme determines the URL scheme for building redirect URLs.
|
||||
// It checks X-Forwarded-Proto header first, then TLS presence.
|
||||
func (h *URLHelper) DetermineScheme(req *http.Request) string {
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// DetermineHost determines the host for building redirect URLs.
|
||||
// It checks X-Forwarded-Host header first, then falls back to req.Host.
|
||||
func (h *URLHelper) DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
@@ -1,899 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test mocks - implementing interfaces defined in oauth_handler.go
|
||||
type mockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Error(msg string) {
|
||||
l.errorMessages = append(l.errorMessages, msg)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
sessionToReturn SessionData
|
||||
errorToReturn error
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
|
||||
return m.sessionToReturn, m.errorToReturn
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
saveError error
|
||||
setAuthError error
|
||||
}
|
||||
|
||||
func (s *mockSessionData) GetCSRF() string { return s.csrf }
|
||||
func (s *mockSessionData) GetNonce() string { return s.nonce }
|
||||
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
|
||||
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
|
||||
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
|
||||
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
|
||||
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
|
||||
func (s *mockSessionData) GetIDToken() string { return s.idToken }
|
||||
func (s *mockSessionData) GetEmail() string { return s.email }
|
||||
|
||||
func (s *mockSessionData) SetAuthenticated(auth bool) error {
|
||||
s.authenticated = auth
|
||||
return s.setAuthError
|
||||
}
|
||||
|
||||
func (s *mockSessionData) SetEmail(email string) { s.email = email }
|
||||
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
|
||||
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
|
||||
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
|
||||
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
|
||||
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
|
||||
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
|
||||
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
|
||||
func (s *mockSessionData) ResetRedirectCount() {}
|
||||
func (s *mockSessionData) returnToPoolSafely() {}
|
||||
|
||||
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return s.saveError
|
||||
}
|
||||
|
||||
type mockTokenExchanger struct {
|
||||
response *TokenResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
return e.response, e.err
|
||||
}
|
||||
|
||||
type mockTokenVerifier struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *mockTokenVerifier) VerifyToken(token string) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
// TestOAuthHandler_NewOAuthHandler tests the constructor
|
||||
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.redirURLPath != "/callback" {
|
||||
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
|
||||
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Session error") {
|
||||
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
|
||||
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Authentication error from provider") {
|
||||
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
|
||||
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "State parameter missing") {
|
||||
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
|
||||
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: ""} // Empty CSRF
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF token missing") {
|
||||
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
|
||||
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "different-token"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF mismatch") {
|
||||
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
|
||||
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "No authorization code received") {
|
||||
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
|
||||
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not exchange code for token") {
|
||||
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
|
||||
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not verify ID token") {
|
||||
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
|
||||
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, errors.New("claims extraction failed")
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not extract claims") {
|
||||
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
// Claims without nonce
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in session") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
|
||||
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce mismatch") {
|
||||
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
|
||||
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
|
||||
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return false } // Disallow all domains
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
|
||||
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
saveError: errors.New("save failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to save session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
|
||||
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
setAuthError: errors.New("set auth failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to update session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
|
||||
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/dashboard",
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{
|
||||
IDToken: "valid-id-token",
|
||||
AccessToken: "valid-access-token",
|
||||
RefreshToken: "valid-refresh-token",
|
||||
}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if errorSent {
|
||||
t.Error("Unexpected error response sent")
|
||||
}
|
||||
|
||||
// Check redirect
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
|
||||
}
|
||||
|
||||
// Verify session data was set correctly
|
||||
if session.email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
|
||||
}
|
||||
|
||||
if session.idToken != "valid-id-token" {
|
||||
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
|
||||
}
|
||||
|
||||
if session.accessToken != "valid-access-token" {
|
||||
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
|
||||
}
|
||||
|
||||
if session.refreshToken != "valid-refresh-token" {
|
||||
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
|
||||
}
|
||||
|
||||
if !session.authenticated {
|
||||
t.Error("Expected session to be authenticated")
|
||||
}
|
||||
|
||||
// Check that temporary fields are cleared
|
||||
if session.csrf != "" {
|
||||
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
|
||||
}
|
||||
|
||||
if session.nonce != "" {
|
||||
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.codeVerifier != "" {
|
||||
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
|
||||
}
|
||||
|
||||
if session.incomingPath != "" {
|
||||
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
|
||||
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "", // No incoming path, should default to "/"
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Check redirect to default path
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
|
||||
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/callback", // Same as redirect URL path
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Should redirect to default path when incoming path is same as callback path
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
@@ -1,454 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestURLHelper_NewURLHelper tests the URLHelper constructor
|
||||
func TestURLHelper_NewURLHelper(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
if helper == nil {
|
||||
t.Fatal("Expected URLHelper to be created, got nil")
|
||||
}
|
||||
|
||||
if helper.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
|
||||
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentURL string
|
||||
excludedURLs map[string]struct{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Exact match",
|
||||
currentURL: "/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Prefix match",
|
||||
currentURL: "/health/status",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - first match",
|
||||
currentURL: "/api/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - second match",
|
||||
currentURL: "/health/check",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded URLs",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Root path exclusion",
|
||||
currentURL: "/anything",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive matching",
|
||||
currentURL: "/API/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Partial substring but not prefix",
|
||||
currentURL: "/user/api/test",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty current URL",
|
||||
currentURL: "",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "URL with query parameters",
|
||||
currentURL: "/health?status=ok",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
|
||||
// Verify debug logging for excluded URLs
|
||||
if result && len(logger.debugMessages) > 0 {
|
||||
// Should have logged a debug message for excluded URL
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if msg == "URL is excluded - got %s / excluded hit: %s" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected debug message for excluded URL")
|
||||
}
|
||||
}
|
||||
|
||||
// Reset logger messages for next test
|
||||
logger.debugMessages = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineScheme tests scheme determination
|
||||
func TestURLHelper_DetermineScheme(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - https",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - http",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "TLS connection without X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "No TLS and no X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto takes precedence over TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Proto falls back to TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineScheme(req)
|
||||
if result != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineHost tests host determination
|
||||
func TestURLHelper_DetermineHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Host header present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "No X-Forwarded-Host, use req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "direct.example.com"
|
||||
return req
|
||||
},
|
||||
expectedHost: "direct.example.com",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Host falls back to req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "fallback.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "")
|
||||
return req
|
||||
},
|
||||
expectedHost: "fallback.example.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com:8080"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com:443",
|
||||
},
|
||||
{
|
||||
name: "req.Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
|
||||
req.Host = "example.com:8080"
|
||||
return req
|
||||
},
|
||||
expectedHost: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "Multiple X-Forwarded-Host values (first one used)",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "first.example.com, second.example.com", // Header value as-is
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineHost(req)
|
||||
if result != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
|
||||
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "Both headers present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, TLS connection",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "secure.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, no TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
|
||||
req.Host = "plain.example.com"
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "plain.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only scheme header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "mixed.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "mixed.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only host header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "external.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
|
||||
scheme := helper.DetermineScheme(req)
|
||||
host := helper.DetermineHost(req)
|
||||
|
||||
if scheme != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
|
||||
}
|
||||
|
||||
if host != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
|
||||
}
|
||||
|
||||
// Test that we can build a complete URL
|
||||
fullURL := scheme + "://" + host + "/callback"
|
||||
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
|
||||
if fullURL != expectedURL {
|
||||
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure the helper methods are performant
|
||||
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
excludedURLs := map[string]struct{}{
|
||||
"/health": {},
|
||||
"/metrics": {},
|
||||
"/status": {},
|
||||
"/api/v1": {},
|
||||
"/api/v2": {},
|
||||
"/static": {},
|
||||
"/assets": {},
|
||||
"/favicon": {},
|
||||
"/robots": {},
|
||||
"/sitemap": {},
|
||||
}
|
||||
|
||||
testURL := "/api/users"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineExcludedURL(testURL, excludedURLs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineScheme(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineHost(req)
|
||||
}
|
||||
}
|
||||
+22
-10
@@ -13,6 +13,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
|
||||
@@ -109,7 +111,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
// Use shared transport pool to prevent memory leaks
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
pooledClient := CreateTokenHTTPClient()
|
||||
client = &http.Client{
|
||||
Transport: pooledClient.Transport,
|
||||
@@ -124,7 +126,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
// Read tokenURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -135,13 +142,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on defer
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
limitReader := io.LimitReader(resp.Body, 1024*10)
|
||||
bodyBytes, _ := io.ReadAll(limitReader)
|
||||
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
@@ -232,7 +239,7 @@ func NewTokenCache() *TokenCache {
|
||||
// - expiration: The duration for which the cache entry should be valid
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves cached claims for a token.
|
||||
@@ -344,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
host := utils.DetermineHost(req)
|
||||
scheme := utils.DetermineScheme(req, t.forceHTTPS)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
@@ -355,8 +362,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
// Read endSessionURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
endSessionURL := t.endSessionURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
|
||||
+13
-20
@@ -12,30 +12,23 @@ import (
|
||||
|
||||
// HTTPClientConfig provides configuration for creating HTTP clients
|
||||
type HTTPClientConfig struct {
|
||||
// Timeout for the entire request
|
||||
Timeout time.Duration
|
||||
// MaxRedirects allowed (0 means follow Go's default of 10)
|
||||
MaxRedirects int
|
||||
// UseCookieJar enables cookie jar for the client
|
||||
UseCookieJar bool
|
||||
// Connection settings
|
||||
IdleConnTimeout time.Duration
|
||||
MaxIdleConns int
|
||||
ReadBufferSize int
|
||||
DialTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
ResponseHeaderTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
IdleConnTimeout time.Duration
|
||||
// Connection pool settings
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
MaxConnsPerHost int
|
||||
// Buffer settings
|
||||
WriteBufferSize int
|
||||
ReadBufferSize int
|
||||
// Feature flags
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
MaxRedirects int
|
||||
MaxIdleConnsPerHost int
|
||||
Timeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig returns the default configuration for general use
|
||||
@@ -245,7 +238,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
|
||||
|
||||
// Add cookie jar if requested
|
||||
if config.UseCookieJar {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
client.Jar = jar
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
|
||||
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
|
||||
config := OIDCProviderHTTPClientConfig()
|
||||
|
||||
// Verify OIDC-specific settings
|
||||
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
|
||||
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
|
||||
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
|
||||
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
|
||||
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
|
||||
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
|
||||
}
|
||||
|
||||
// TestCreateDefaultClientUnit tests CreateDefaultClient function
|
||||
func TestCreateDefaultClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateDefaultClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateTokenClientUnit tests CreateTokenClient function
|
||||
func TestCreateTokenClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateTokenClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.NotNil(t, client.Jar, "token client should have cookie jar")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
|
||||
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxIdleConns: 20,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
UseCookieJar: true,
|
||||
}
|
||||
|
||||
client := CreateHTTPClientWithConfig(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 5*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
|
||||
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
t.Run("zero values get defaults", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
// All zero values
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Verify defaults were applied
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("custom values preserved", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 15 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxRedirects: 3,
|
||||
UseCookieJar: true,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 15*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar)
|
||||
})
|
||||
|
||||
t.Run("invalid timeout gets default", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: -1 * time.Second, // Invalid
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Should get default due to validation failure
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
|
||||
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMsg string
|
||||
config HTTPClientConfig
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 20,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConns",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConns too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 1500,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns too high",
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "timeout too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Minute,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout too high",
|
||||
},
|
||||
{
|
||||
name: "negative timeout",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: -1 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout must be positive",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: 50,
|
||||
MaxConnsPerHost: 10,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := factory.ValidateHTTPClientConfig(&tt.config)
|
||||
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+48
-16
@@ -12,19 +12,19 @@ import (
|
||||
|
||||
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
|
||||
type SharedTransportPool struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*sharedTransport
|
||||
maxConns int
|
||||
ctx context.Context
|
||||
transports map[string]*sharedTransport
|
||||
cancel context.CancelFunc
|
||||
clientCount int32 // SECURITY FIX: Track total HTTP clients
|
||||
maxClients int32 // SECURITY FIX: Limit total clients to 5
|
||||
maxConns int
|
||||
mu sync.RWMutex
|
||||
clientCount int32
|
||||
maxClients int32
|
||||
}
|
||||
|
||||
type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
|
||||
}
|
||||
|
||||
// cleanupIdleTransports periodically cleans up unused transports
|
||||
// Uses two-phase cleanup to minimize lock contention:
|
||||
// 1. Find candidates while holding read lock
|
||||
// 2. Remove and close transports with minimal lock duration
|
||||
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
@@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(p.transports, transportKey)
|
||||
// SECURITY FIX: Decrement client count when removing transport
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
}
|
||||
p.performCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performCleanup does the actual cleanup with optimized locking
|
||||
func (p *SharedTransportPool) performCleanup() {
|
||||
now := time.Now()
|
||||
|
||||
// Phase 1: Find candidates while holding read lock (fast)
|
||||
p.mu.RLock()
|
||||
candidates := make([]string, 0)
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
candidates = append(candidates, transportKey)
|
||||
}
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 2: Remove and close each candidate individually
|
||||
// This minimizes lock contention and allows concurrent access
|
||||
for _, key := range candidates {
|
||||
p.mu.Lock()
|
||||
shared, exists := p.transports[key]
|
||||
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
// Remove from map first (releases memory)
|
||||
delete(p.transports, key)
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
p.mu.Unlock()
|
||||
|
||||
// Close idle connections outside the lock (can be slow)
|
||||
if shared.transport != nil {
|
||||
shared.transport.CloseIdleConnections()
|
||||
}
|
||||
} else {
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
|
||||
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
|
||||
t.Run("create new transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
require.NotNil(t, transport)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
|
||||
assert.Len(t, pool.transports, 1)
|
||||
})
|
||||
|
||||
t.Run("reuse existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config)
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Equal(t, transport1, transport2, "should reuse same transport")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
|
||||
|
||||
// Check ref count
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
|
||||
})
|
||||
|
||||
t.Run("client limit enforcement", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 5, // Already at max
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Nil(t, transport, "should return nil when at client limit")
|
||||
})
|
||||
|
||||
t.Run("client limit with existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create first transport
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config1)
|
||||
require.NotNil(t, transport1)
|
||||
|
||||
// Set client count to max
|
||||
atomic.StoreInt32(&pool.clientCount, 5)
|
||||
|
||||
// Try to create with different config
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15 // Different config
|
||||
transport2 := pool.GetOrCreateTransport(config2)
|
||||
|
||||
// Should return existing transport since at limit
|
||||
assert.NotNil(t, transport2)
|
||||
assert.Equal(t, transport1, transport2)
|
||||
})
|
||||
|
||||
t.Run("updates last used time", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
firstTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Get again
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport2)
|
||||
|
||||
pool.mu.RLock()
|
||||
secondTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolReleaseTransport tests transport release
|
||||
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
|
||||
t.Run("decrement ref count", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Get again to increase ref count
|
||||
pool.GetOrCreateTransport(config)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 2, refCount)
|
||||
|
||||
// Release
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
newRefCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 1, newRefCount)
|
||||
})
|
||||
|
||||
t.Run("ref count reaches zero", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
pool.mu.RUnlock()
|
||||
|
||||
// Release to zero
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 0, shared.refCount)
|
||||
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
|
||||
})
|
||||
|
||||
t.Run("release non-existent transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create a transport not in the pool
|
||||
fakeTransport := &http.Transport{}
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
pool.ReleaseTransport(fakeTransport)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("release updates last used", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
beforeRelease := time.Now()
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
lastUsed := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanup tests cleanup functionality
|
||||
func TestSharedTransportPoolCleanup(t *testing.T) {
|
||||
t.Run("cleanup all transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create multiple transports
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
pool.GetOrCreateTransport(config1)
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15
|
||||
pool.GetOrCreateTransport(config2)
|
||||
|
||||
assert.Greater(t, len(pool.transports), 0)
|
||||
|
||||
// Cleanup
|
||||
pool.Cleanup()
|
||||
|
||||
assert.Len(t, pool.transports, 0, "all transports should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup cancels context", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
pool.Cleanup()
|
||||
|
||||
select {
|
||||
case <-pool.ctx.Done():
|
||||
// Context was canceled
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("context should be canceled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cleanup with no transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
pool.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("cleanup closes idle connections", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Cleanup should call CloseIdleConnections on each transport
|
||||
pool.Cleanup()
|
||||
|
||||
// Verify transports map is cleared
|
||||
assert.Empty(t, pool.transports)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
|
||||
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cleanup goroutine test in short mode")
|
||||
}
|
||||
|
||||
t.Run("cleanup removes idle transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport and release it
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
// Set lastUsed to old time
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Start cleanup in background (simulating what would happen)
|
||||
// Note: We're testing the cleanup logic manually here
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should be removed
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.False(t, exists, "old idle transport should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup preserves active transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport with refs
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Keep ref count > 0, but set old lastUsed
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup logic
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should still exist (has ref count)
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, exists, "transport with references should be preserved")
|
||||
})
|
||||
|
||||
t.Run("cleanup respects context cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
pool.cleanupIdleTransports(ctx)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Should exit quickly
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("cleanup goroutine should exit on context cancellation")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCreatePooledHTTPClient tests pooled client creation
|
||||
func TestCreatePooledHTTPClient(t *testing.T) {
|
||||
t.Run("create client with default config", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport)
|
||||
assert.Equal(t, config.Timeout, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("create multiple clients reuse transport", func(t *testing.T) {
|
||||
// Reset global pool for clean test
|
||||
globalTransportPoolOnce = sync.Once{}
|
||||
globalTransportPool = nil
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
client1 := CreatePooledHTTPClient(config)
|
||||
client2 := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client1)
|
||||
require.NotNil(t, client2)
|
||||
|
||||
// Should use same transport
|
||||
assert.Equal(t, client1.Transport, client2.Transport)
|
||||
})
|
||||
|
||||
t.Run("redirect policy is set", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 3
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.CheckRedirect)
|
||||
|
||||
// Test redirect limit
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 3; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after max redirects")
|
||||
})
|
||||
|
||||
t.Run("default redirect limit", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 0 // Should default to 10
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Test default redirect limit (10)
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 10; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after 10 redirects")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetGlobalTransportPool tests singleton pattern
|
||||
func TestGetGlobalTransportPool(t *testing.T) {
|
||||
t.Run("returns same instance", func(t *testing.T) {
|
||||
pool1 := GetGlobalTransportPool()
|
||||
pool2 := GetGlobalTransportPool()
|
||||
|
||||
assert.Equal(t, pool1, pool2, "should return same singleton instance")
|
||||
})
|
||||
|
||||
t.Run("pool is initialized", func(t *testing.T) {
|
||||
pool := GetGlobalTransportPool()
|
||||
|
||||
require.NotNil(t, pool)
|
||||
assert.NotNil(t, pool.transports)
|
||||
assert.Equal(t, 20, pool.maxConns)
|
||||
assert.Equal(t, int32(5), pool.maxClients)
|
||||
assert.NotNil(t, pool.ctx)
|
||||
assert.NotNil(t, pool.cancel)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolConcurrency tests thread safety
|
||||
func TestSharedTransportPoolConcurrency(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrency test in short mode")
|
||||
}
|
||||
|
||||
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10, // Allow more for concurrency test
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
const numGoroutines = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
transports := make([]*http.Transport, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
transports[idx] = pool.GetOrCreateTransport(config)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All should get same transport
|
||||
firstTransport := transports[0]
|
||||
for i := 1; i < numGoroutines; i++ {
|
||||
if transports[i] != nil {
|
||||
assert.Equal(t, firstTransport, transports[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
// Increase ref count
|
||||
for i := 0; i < 20; i++ {
|
||||
pool.GetOrCreateTransport(config)
|
||||
}
|
||||
|
||||
const numReleases = 20
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numReleases; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.ReleaseTransport(transport)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic and ref count should be decremented
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolEdgeCases tests edge cases
|
||||
func TestSharedTransportPoolEdgeCases(t *testing.T) {
|
||||
t.Run("config key generation", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
config1.MaxIdleConnsPerHost = 5
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 10
|
||||
config2.MaxIdleConnsPerHost = 5
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.Equal(t, key1, key2, "same config should produce same key")
|
||||
})
|
||||
|
||||
t.Run("different configs produce different keys", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 20
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
|
||||
})
|
||||
|
||||
t.Run("client count decrements on cleanup", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
initialCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(1), initialCount)
|
||||
|
||||
// Release and mark as old
|
||||
pool.ReleaseTransport(transport)
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
finalCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
|
||||
})
|
||||
}
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Valid token validation", func(t *testing.T) {
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
|
||||
|
||||
result := validator.ValidateToken(validToken)
|
||||
if !result.IsValid {
|
||||
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidJWTToken",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
|
||||
expectValid: true,
|
||||
description: "Valid JWT token should pass validation",
|
||||
},
|
||||
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "MaliciousJWTWithExtraData",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
|
||||
expectValid: false,
|
||||
description: "JWT with extra malicious data should fail validation",
|
||||
},
|
||||
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidEmail",
|
||||
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHTTPSURL",
|
||||
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
|
||||
name string
|
||||
claimName string
|
||||
claimValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidStringClaim",
|
||||
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
|
||||
name string
|
||||
headerName string
|
||||
headerValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHeader",
|
||||
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidUsername",
|
||||
|
||||
@@ -726,20 +726,20 @@ type MockConfig struct {
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
id string
|
||||
userID string
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
data map[string]interface{}
|
||||
id string
|
||||
userID string
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
UserID int
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Error error
|
||||
UserID int
|
||||
Duration time.Duration
|
||||
Success bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
Vendored
+79
@@ -0,0 +1,79 @@
|
||||
package backends
|
||||
|
||||
import "time"
|
||||
|
||||
// BackendType represents the type of cache backend
|
||||
type BackendType string
|
||||
|
||||
const (
|
||||
BackendTypeMemory BackendType = "memory"
|
||||
BackendTypeRedis BackendType = "redis"
|
||||
BackendTypeHybrid BackendType = "hybrid"
|
||||
|
||||
// Aliases for backward compatibility
|
||||
TypeMemory BackendType = "memory"
|
||||
TypeRedis BackendType = "redis"
|
||||
TypeHybrid BackendType = "hybrid"
|
||||
)
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
L2Config *Config
|
||||
L1Config *Config
|
||||
RedisPrefix string
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
MaxMemoryBytes int64
|
||||
MaxSize int
|
||||
HealthCheckInterval time.Duration
|
||||
AsyncWrites bool
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns a default configuration for Redis caching
|
||||
func DefaultRedisConfig(addr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeRedis,
|
||||
RedisAddr: addr,
|
||||
RedisDB: 0,
|
||||
RedisPrefix: "traefikoidc:",
|
||||
PoolSize: 10,
|
||||
EnableCircuitBreaker: true,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHybridConfig returns a default configuration for hybrid caching
|
||||
func DefaultHybridConfig(redisAddr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeHybrid,
|
||||
L1Config: &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 500,
|
||||
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
},
|
||||
L2Config: DefaultRedisConfig(redisAddr),
|
||||
AsyncWrites: true,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
+59
@@ -0,0 +1,59 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDefaultHybridConfig verifies the default hybrid configuration
|
||||
func TestDefaultHybridConfig(t *testing.T) {
|
||||
redisAddr := "localhost:6379"
|
||||
|
||||
config := DefaultHybridConfig(redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Verify top-level config
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.True(t, config.AsyncWrites)
|
||||
assert.True(t, config.EnableMetrics)
|
||||
|
||||
// Verify L1 (memory) config
|
||||
require.NotNil(t, config.L1Config)
|
||||
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
|
||||
assert.Equal(t, 500, config.L1Config.MaxSize)
|
||||
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
|
||||
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
|
||||
|
||||
// Verify L2 (Redis) config exists
|
||||
require.NotNil(t, config.L2Config)
|
||||
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
|
||||
}
|
||||
|
||||
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
redisAddr string
|
||||
}{
|
||||
{"localhost", "localhost:6379"},
|
||||
{"remote host", "redis.example.com:6379"},
|
||||
{"IP address", "192.168.1.100:6379"},
|
||||
{"custom port", "localhost:6380"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := DefaultHybridConfig(tt.redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.NotNil(t, config.L1Config)
|
||||
assert.NotNil(t, config.L2Config)
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
+38
@@ -0,0 +1,38 @@
|
||||
package backends
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrBackendClosed is returned when operating on a closed backend
|
||||
ErrBackendClosed = errors.New("cache backend is closed")
|
||||
|
||||
// ErrKeyNotFound is returned when a key doesn't exist
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
|
||||
// ErrCacheMiss indicates the requested key was not found in the cache
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// ErrBackendUnavailable indicates the cache backend is not available
|
||||
ErrBackendUnavailable = errors.New("cache backend unavailable")
|
||||
|
||||
// ErrInvalidValue indicates the cached value is invalid or corrupted
|
||||
ErrInvalidValue = errors.New("invalid cached value")
|
||||
|
||||
// ErrInvalidTTL is returned when TTL is invalid
|
||||
ErrInvalidTTL = errors.New("invalid TTL")
|
||||
|
||||
// ErrConnectionFailed is returned when connection fails
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrCircuitOpen is returned when circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTimeout is returned when operation times out
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
|
||||
// ErrSerializationFailed is returned when serialization fails
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
|
||||
// ErrDeserializationFailed is returned when deserialization fails
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
)
|
||||
Vendored
+685
@@ -0,0 +1,685 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
lastL2Error atomic.Value
|
||||
secondary CacheBackend
|
||||
primary CacheBackend
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
errors atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
ctx context.Context
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// defaultLogger provides a basic logger implementation
|
||||
type defaultLogger struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
l.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
|
||||
l.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// HybridConfig provides configuration for the hybrid backend
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
Logger Logger
|
||||
SyncWriteCacheTypes map[string]bool
|
||||
AsyncBufferSize int
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.Primary == nil {
|
||||
return nil, fmt.Errorf("primary (L1) backend is required")
|
||||
}
|
||||
|
||||
if config.Secondary == nil {
|
||||
return nil, fmt.Errorf("secondary (L2) backend is required")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
|
||||
}
|
||||
|
||||
if config.AsyncBufferSize <= 0 {
|
||||
config.AsyncBufferSize = 1000
|
||||
}
|
||||
|
||||
// Default critical cache types that require synchronous writes
|
||||
if config.SyncWriteCacheTypes == nil {
|
||||
config.SyncWriteCacheTypes = map[string]bool{
|
||||
"blacklist": true, // Token blacklist must be immediately consistent
|
||||
"token": true, // Token validation is critical
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &HybridBackend{
|
||||
primary: config.Primary,
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
// Start async write worker
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
|
||||
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
|
||||
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
|
||||
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Set stores a value in both L1 and L2 caches
|
||||
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Always write to L1 first (synchronous)
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L1 cache: %v", err)
|
||||
// Continue to try L2 even if L1 fails
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
// Determine if this should be a sync or async write based on cache type
|
||||
cacheType := h.extractCacheType(key)
|
||||
requiresSync := h.syncWriteCacheTypes[cacheType]
|
||||
|
||||
if requiresSync {
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from cache, checking L1 first, then L2
|
||||
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
// Try L1 first
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
h.l1Hits.Add(1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
// Try L2
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
}
|
||||
|
||||
if !exists {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from both L1 and L2 caches
|
||||
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
var deleted bool
|
||||
|
||||
// Delete from L1
|
||||
if d, err := h.primary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
|
||||
// Delete from L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if d, err := h.secondary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in either cache
|
||||
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// Check L1 first
|
||||
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys from both caches
|
||||
func (h *HybridBackend) Clear(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
// Clear L1
|
||||
if err := h.primary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L1 cache: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
// Clear L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if err := h.secondary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetStats returns statistics for the hybrid cache
|
||||
func (h *HybridBackend) GetStats() map[string]interface{} {
|
||||
l1Hits := h.l1Hits.Load()
|
||||
l2Hits := h.l2Hits.Load()
|
||||
misses := h.misses.Load()
|
||||
total := l1Hits + l2Hits + misses
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": TypeHybrid,
|
||||
"l1_hits": l1Hits,
|
||||
"l2_hits": l2Hits,
|
||||
"misses": misses,
|
||||
"total": total,
|
||||
"l1_writes": h.l1Writes.Load(),
|
||||
"l2_writes": h.l2Writes.Load(),
|
||||
"errors": h.errors.Load(),
|
||||
"fallback_mode": h.fallbackMode.Load(),
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
|
||||
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
|
||||
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
|
||||
}
|
||||
|
||||
// Add sub-backend stats
|
||||
stats["l1_stats"] = h.primary.GetStats()
|
||||
stats["l2_stats"] = h.secondary.GetStats()
|
||||
|
||||
// Add last L2 error time if available
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok {
|
||||
stats["last_l2_error"] = t.Format(time.RFC3339)
|
||||
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks if both backends are healthy
|
||||
func (h *HybridBackend) Ping(ctx context.Context) error {
|
||||
// Check L1
|
||||
if err := h.primary.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("L1 ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check L2 (but don't fail if it's down)
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
h.logger.Warnf("L2 ping failed: %v", err)
|
||||
h.recordL2Error()
|
||||
// Don't return error - we can operate with L1 only
|
||||
} else {
|
||||
// L2 is healthy, clear fallback mode if it was set
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend recovered, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the hybrid backend
|
||||
func (h *HybridBackend) Close() error {
|
||||
// Cancel context to stop workers
|
||||
h.cancel()
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Workers finished
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warnf("Timeout waiting for workers to finish")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Close backends
|
||||
if err := h.primary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L1 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if err := h.secondary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L2 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
h.logger.Infof("HybridBackend closed")
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values efficiently
|
||||
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
results := make(map[string][]byte, len(keys))
|
||||
missingKeys := make([]string, 0)
|
||||
|
||||
// Try L1 first for all keys
|
||||
for _, key := range keys {
|
||||
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
|
||||
results[key] = value
|
||||
h.l1Hits.Add(1)
|
||||
} else {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// If all found in L1 or in fallback mode, return
|
||||
if len(missingKeys) == 0 || h.fallbackMode.Load() {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Try L2 for missing keys using batch operation if available
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
GetMany(context.Context, []string) (map[string][]byte, error)
|
||||
}); ok {
|
||||
l2Results, err := batcher.GetMany(ctx, missingKeys)
|
||||
if err != nil {
|
||||
h.logger.Debugf("L2 batch get error: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual gets
|
||||
for _, key := range missingKeys {
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count misses for keys not found anywhere
|
||||
for _, key := range keys {
|
||||
if _, found := results[key]; !found {
|
||||
h.misses.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetMany stores multiple key-value pairs efficiently
|
||||
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write to L1 first
|
||||
for key, value := range items {
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip L2 if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if L2 supports batch operations
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
SetMany(context.Context, map[string][]byte, time.Duration) error
|
||||
}); ok {
|
||||
if err := batcher.SetMany(ctx, items, ttl); err != nil {
|
||||
h.logger.Warnf("Failed to batch write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(int64(len(items)))
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual sets
|
||||
for key, value := range items {
|
||||
cacheType := h.extractCacheType(key)
|
||||
if h.syncWriteCacheTypes[cacheType] {
|
||||
// Sync write for critical types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
}
|
||||
} else {
|
||||
// Async write for non-critical types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
// Queued
|
||||
default:
|
||||
h.logger.Warnf("Async buffer full for batch write")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items with best effort
|
||||
for len(h.asyncWriteBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.asyncWriteBuffer:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.asyncWriteBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the write with a timeout
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthMonitor periodically checks L2 health and manages fallback mode
|
||||
func (h *HybridBackend) healthMonitor() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
if !h.fallbackMode.Load() {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
|
||||
}
|
||||
} else {
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend healthy, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordL2Error records the timestamp of an L2 error
|
||||
func (h *HybridBackend) recordL2Error() {
|
||||
h.lastL2Error.Store(time.Now())
|
||||
|
||||
// Check if we should enter fallback mode based on recent errors
|
||||
if !h.fallbackMode.Load() {
|
||||
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractCacheType attempts to determine the cache type from the key
|
||||
func (h *HybridBackend) extractCacheType(key string) string {
|
||||
// Simple heuristic based on key prefixes
|
||||
// This should match the actual cache type strategy in the main application
|
||||
|
||||
if len(key) > 10 {
|
||||
prefix := key[:10]
|
||||
switch {
|
||||
case contains(prefix, "blacklist"):
|
||||
return "blacklist"
|
||||
case contains(prefix, "token"):
|
||||
return "token"
|
||||
case contains(prefix, "metadata"):
|
||||
return "metadata"
|
||||
case contains(prefix, "jwk"):
|
||||
return "jwk"
|
||||
case contains(prefix, "session"):
|
||||
return "session"
|
||||
case contains(prefix, "introspect"):
|
||||
return "introspection"
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if toLower(s[i+j]) != toLower(substr[j]) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower converts a byte to lowercase
|
||||
func toLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + 32
|
||||
}
|
||||
return b
|
||||
}
|
||||
+1490
File diff suppressed because it is too large
Load Diff
Vendored
+102
@@ -0,0 +1,102 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheBackend defines the interface for all cache backend implementations
|
||||
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
|
||||
type CacheBackend interface {
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
// Returns an error if the operation fails
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
// Returns: value, remaining TTL, exists flag, and error
|
||||
// If the key doesn't exist, exists will be false
|
||||
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
|
||||
|
||||
// Delete removes a key from the cache
|
||||
// Returns true if the key was deleted, false if it didn't exist
|
||||
Delete(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// GetStats returns cache statistics
|
||||
// Stats include: hits, misses, size, memory usage, etc.
|
||||
GetStats() map[string]interface{}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
Close() error
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
StartTime time.Time
|
||||
LastErrorTime time.Time
|
||||
Type BackendType
|
||||
LastError string
|
||||
Deletes int64
|
||||
Errors int64
|
||||
Evictions int64
|
||||
CurrentSize int64
|
||||
MaxSize int64
|
||||
MemoryUsage int64
|
||||
AverageGetLatency time.Duration
|
||||
AverageSetLatency time.Duration
|
||||
Sets int64
|
||||
Misses int64
|
||||
Uptime time.Duration
|
||||
Hits int64
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
type BackendCapabilities struct {
|
||||
// Distributed indicates if the backend is distributed across multiple instances
|
||||
Distributed bool
|
||||
|
||||
// Persistent indicates if the backend persists data across restarts
|
||||
Persistent bool
|
||||
|
||||
// Eviction indicates if the backend supports automatic eviction
|
||||
Eviction bool
|
||||
|
||||
// TTL indicates if the backend supports TTL (time-to-live)
|
||||
TTL bool
|
||||
|
||||
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
|
||||
MaxKeySize int64
|
||||
|
||||
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
|
||||
MaxValueSize int64
|
||||
|
||||
// MaxKeys is the maximum number of keys (0 = unlimited)
|
||||
MaxKeys int64
|
||||
|
||||
// SupportsExpire indicates if the backend supports expiration
|
||||
SupportsExpire bool
|
||||
|
||||
// SupportsMultiGet indicates if the backend supports batch get operations
|
||||
SupportsMultiGet bool
|
||||
|
||||
// SupportsTransaction indicates if the backend supports transactions
|
||||
SupportsTransaction bool
|
||||
|
||||
// SupportsCompression indicates if the backend supports compression
|
||||
SupportsCompression bool
|
||||
|
||||
// RequiresSerialize indicates if values must be serialized
|
||||
RequiresSerialize bool
|
||||
|
||||
// AtomicOperations indicates if the backend supports atomic operations
|
||||
AtomicOperations bool
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
|
||||
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
|
||||
func TestCacheBackendContract(t *testing.T) {
|
||||
// Test suite will be run against each backend type
|
||||
t.Run("MemoryBackend", func(t *testing.T) {
|
||||
backend := setupMemoryBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("RedisBackend", func(t *testing.T) {
|
||||
backend := setupRedisBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("HybridBackend", func(t *testing.T) {
|
||||
backend := setupHybridBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// runContractTests executes all contract tests against a backend
|
||||
func runContractTests(t *testing.T, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BasicSetGet", func(t *testing.T) {
|
||||
testBasicSetGet(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
testGetNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
testUpdateExisting(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testDelete(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
testDeleteNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
testExists(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("TTLExpiration", func(t *testing.T) {
|
||||
testTTLExpiration(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
testClear(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
testPing(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
testStats(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
testConcurrentAccess(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("LargeValues", func(t *testing.T) {
|
||||
testLargeValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("EmptyValues", func(t *testing.T) {
|
||||
testEmptyValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
|
||||
testSpecialCharactersInKeys(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// testBasicSetGet verifies basic set and get operations
|
||||
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "test-key-1"
|
||||
value := []byte("test-value-1")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err, "Set should not return error")
|
||||
|
||||
// Get value
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error")
|
||||
assert.True(t, exists, "Key should exist")
|
||||
assert.Equal(t, value, retrieved, "Retrieved value should match")
|
||||
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
|
||||
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
|
||||
}
|
||||
|
||||
// testGetNonExistent verifies behavior when getting non-existent keys
|
||||
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-key"
|
||||
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error for non-existent key")
|
||||
assert.False(t, exists, "Key should not exist")
|
||||
assert.Nil(t, retrieved, "Value should be nil")
|
||||
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
|
||||
}
|
||||
|
||||
// testUpdateExisting verifies updating an existing key
|
||||
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set initial value
|
||||
err := backend.Set(ctx, key, value1, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update value
|
||||
err = backend.Set(ctx, key, value2, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated value
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved, "Value should be updated")
|
||||
}
|
||||
|
||||
// testDelete verifies delete operation
|
||||
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted, "Delete should return true for existing key")
|
||||
|
||||
// Verify deleted
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after delete")
|
||||
}
|
||||
|
||||
// testDeleteNonExistent verifies deleting non-existent keys
|
||||
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-delete-key"
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted, "Delete should return false for non-existent key")
|
||||
}
|
||||
|
||||
// testExists verifies the Exists operation
|
||||
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist initially")
|
||||
|
||||
// Set value
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist after Set")
|
||||
}
|
||||
|
||||
// testTTLExpiration verifies TTL expiration behavior
|
||||
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
// Set with short TTL
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist immediately after Set")
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after TTL expiration")
|
||||
}
|
||||
|
||||
// testClear verifies Clear operation
|
||||
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
// Set multiple keys
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Give async writes time to complete before clearing
|
||||
// This prevents race conditions with async write workers
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Clear all
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after Clear")
|
||||
}
|
||||
}
|
||||
|
||||
// testPing verifies Ping operation
|
||||
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
err := backend.Ping(ctx)
|
||||
assert.NoError(t, err, "Ping should succeed on healthy backend")
|
||||
}
|
||||
|
||||
// testStats verifies GetStats operation
|
||||
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
stats := backend.GetStats()
|
||||
assert.NotNil(t, stats, "Stats should not be nil")
|
||||
|
||||
// Stats should contain basic metrics
|
||||
_, hasHits := stats["hits"]
|
||||
_, hasMisses := stats["misses"]
|
||||
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
|
||||
}
|
||||
|
||||
// testConcurrentAccess verifies thread safety
|
||||
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 20
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// testLargeValues verifies handling of large values
|
||||
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "large-value-key"
|
||||
value := GenerateLargeValue(1024 * 1024) // 1MB
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle large values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
|
||||
}
|
||||
|
||||
// testEmptyValues verifies handling of empty values
|
||||
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "empty-value-key"
|
||||
value := []byte{}
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle empty values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Empty value should exist")
|
||||
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
|
||||
}
|
||||
|
||||
// testSpecialCharactersInKeys verifies handling of special characters in keys
|
||||
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
specialKeys := []string{
|
||||
"key:with:colons",
|
||||
"key/with/slashes",
|
||||
"key-with-dashes",
|
||||
"key_with_underscores",
|
||||
"key.with.dots",
|
||||
"key|with|pipes",
|
||||
}
|
||||
|
||||
for _, key := range specialKeys {
|
||||
value := []byte(fmt.Sprintf("value-for-%s", key))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle special character in key: %s", key)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key with special characters should exist: %s", key)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to setup different backend types
|
||||
// These will be implemented in respective test files
|
||||
|
||||
func setupMemoryBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in memory_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("MemoryBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRedisBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in redis_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("RedisBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHybridBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
config := &HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 100,
|
||||
Logger: NewTestLogger(t),
|
||||
}
|
||||
|
||||
hybrid, err := NewHybridBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
hybrid.Close()
|
||||
})
|
||||
|
||||
return hybrid
|
||||
}
|
||||
Vendored
+535
@@ -0,0 +1,535 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default configuration values
|
||||
const (
|
||||
defaultShardCount = 256
|
||||
defaultMaxSize = int64(10000)
|
||||
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
|
||||
defaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element interface{} // *list.Element, using interface{} to avoid import cycle
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
func (item *memoryCacheItem) isExpired() bool {
|
||||
if item.expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
|
||||
// The sharded design reduces lock contention by partitioning keys across multiple shards,
|
||||
// each with its own lock.
|
||||
type MemoryCacheBackend struct {
|
||||
shards []*cacheShard
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
cleanupDone chan struct{}
|
||||
cleanupTicker *time.Ticker
|
||||
lastError string
|
||||
shardCount uint32
|
||||
shardMask uint32
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
|
||||
// Global stats (aggregated from shards)
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// State
|
||||
closed atomic.Bool
|
||||
mu sync.RWMutex // For global operations like stats and error tracking
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new sharded memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = defaultMaxSize
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = defaultMaxMemory
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
shardCount := uint32(defaultShardCount)
|
||||
|
||||
// For very small caches, reduce shard count to maintain sensible per-shard limits
|
||||
// Ensure each shard can hold at least 2 items for proper LRU behavior
|
||||
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
|
||||
shardCount /= 2
|
||||
}
|
||||
if shardCount < 1 {
|
||||
shardCount = 1
|
||||
}
|
||||
|
||||
// Per-shard limits are soft hints; global limits are enforced
|
||||
// Give shards 2x the average to allow for uneven distribution
|
||||
shardMaxSize := (maxSize * 2) / int64(shardCount)
|
||||
if shardMaxSize < 4 {
|
||||
shardMaxSize = 4
|
||||
}
|
||||
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
|
||||
if shardMaxMemory < 4096 {
|
||||
shardMaxMemory = 4096 // Minimum 4KB per shard
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
shards: make([]*cacheShard, shardCount),
|
||||
shardCount: shardCount,
|
||||
shardMask: shardCount - 1, // For fast modulo with power-of-2
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for i := uint32(0); i < shardCount; i++ {
|
||||
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
m.cleanupTicker = time.NewTicker(cleanupInterval)
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// getShard returns the shard for a given key
|
||||
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
|
||||
hash := fnv32(key)
|
||||
return m.shards[hash&m.shardMask]
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupExpired()
|
||||
case <-m.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from all shards
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
totalRemoved := 0
|
||||
for _, shard := range m.shards {
|
||||
totalRemoved += shard.cleanup()
|
||||
}
|
||||
|
||||
if totalRemoved > 0 {
|
||||
m.evictions.Add(int64(totalRemoved))
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalGetTime.Add(duration)
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
shard := m.getShard(key)
|
||||
value, exists, expired := shard.get(key)
|
||||
|
||||
if expired {
|
||||
// Clean up expired item
|
||||
shard.delete(key)
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
m.hits.Add(1)
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalSetTime.Add(duration)
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
// Enforce global limits before adding new item
|
||||
m.enforceGlobalLimits(itemSize)
|
||||
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
shard := m.getShard(key)
|
||||
shard.set(key, value, expiresAt, itemSize)
|
||||
|
||||
m.sets.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceGlobalLimits ensures global size and memory limits are respected
|
||||
// by evicting from shards when necessary
|
||||
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
|
||||
// Check and enforce size limit
|
||||
for {
|
||||
totalSize, totalMemory := m.getGlobalStats()
|
||||
|
||||
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
|
||||
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
|
||||
|
||||
if !needsSizeEviction && !needsMemoryEviction {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the shard with the most items and evict from it
|
||||
evicted := m.evictFromLargestShard()
|
||||
if !evicted {
|
||||
break // No more items to evict
|
||||
}
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// getGlobalStats returns the total size and memory usage across all shards
|
||||
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// evictFromLargestShard evicts the globally oldest item across all shards
|
||||
// This provides true LRU behavior even with sharding
|
||||
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
|
||||
var oldestShard *cacheShard
|
||||
var oldestTime time.Time
|
||||
|
||||
for _, shard := range m.shards {
|
||||
accessTime := shard.getOldestAccessTime()
|
||||
// Skip empty shards
|
||||
if accessTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Find the shard with the oldest (earliest) access time
|
||||
if oldestShard == nil || accessTime.Before(oldestTime) {
|
||||
oldestTime = accessTime
|
||||
oldestShard = shard
|
||||
}
|
||||
}
|
||||
|
||||
if oldestShard == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return oldestShard.evictOne()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
shard := m.getShard(key)
|
||||
if shard.delete(key) {
|
||||
m.deletes.Add(1)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
shard := m.getShard(key)
|
||||
return shard.exists(key), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching the pattern (use "*" for all keys)
|
||||
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
var allKeys []string
|
||||
for _, shard := range m.shards {
|
||||
keys := shard.keys(pattern)
|
||||
allKeys = append(allKeys, keys...)
|
||||
}
|
||||
|
||||
return allKeys, nil
|
||||
}
|
||||
|
||||
// Size returns the total number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, shard := range m.shards {
|
||||
size, _ := shard.stats()
|
||||
total += size
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
shard := m.getShard(key)
|
||||
ttl, exists := shard.ttl(key)
|
||||
if !exists {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
shard := m.getShard(key)
|
||||
if !shard.expire(key, ttl) {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the cache backend
|
||||
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
// Aggregate stats from all shards
|
||||
var totalSize, totalMemory int64
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
m.mu.RUnlock()
|
||||
|
||||
avgGetLatency := time.Duration(0)
|
||||
if getCount := m.getCount.Load(); getCount > 0 {
|
||||
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
|
||||
}
|
||||
|
||||
avgSetLatency := time.Duration(0)
|
||||
if setCount := m.setCount.Load(); setCount > 0 {
|
||||
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
|
||||
}
|
||||
|
||||
return &BackendStats{
|
||||
Type: TypeMemory,
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Sets: m.sets.Load(),
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: totalSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: totalMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
LastErrorTime: lastErrorTime,
|
||||
Uptime: time.Since(m.startTime),
|
||||
StartTime: m.startTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the backend and releases resources
|
||||
func (m *MemoryCacheBackend) Close() error {
|
||||
if m.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
// Clear all shards
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (m *MemoryCacheBackend) IsHealthy() bool {
|
||||
return !m.closed.Load()
|
||||
}
|
||||
|
||||
// Type returns the backend type
|
||||
func (m *MemoryCacheBackend) Type() BackendType {
|
||||
return TypeMemory
|
||||
}
|
||||
|
||||
// Capabilities returns the backend capabilities
|
||||
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
return &BackendCapabilities{
|
||||
Distributed: false,
|
||||
Persistent: false,
|
||||
Eviction: true,
|
||||
TTL: true,
|
||||
MaxKeySize: 1024, // 1KB
|
||||
MaxValueSize: 10485760, // 10MB
|
||||
MaxKeys: m.maxSize,
|
||||
SupportsExpire: true,
|
||||
SupportsMultiGet: true,
|
||||
SupportsTransaction: false,
|
||||
SupportsCompression: false,
|
||||
RequiresSerialize: false,
|
||||
}
|
||||
}
|
||||
|
||||
// GetShardCount returns the number of shards (for testing/monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardCount() uint32 {
|
||||
return m.shardCount
|
||||
}
|
||||
|
||||
// GetShardStats returns per-shard statistics (for monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
|
||||
stats := make([]map[string]int64, m.shardCount)
|
||||
for i, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
stats[i] = map[string]int64{
|
||||
"size": size,
|
||||
"memory": memory,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
case []byte:
|
||||
return int64(len(v))
|
||||
case int, int32, int64, uint, uint32, uint64:
|
||||
return 8
|
||||
case float32, float64:
|
||||
return 8
|
||||
case bool:
|
||||
return 1
|
||||
default:
|
||||
// For complex types, use a default estimate
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
// matchPattern checks if a key matches a pattern (simplified glob matching)
|
||||
func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
suffix := pattern[1:]
|
||||
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
|
||||
}
|
||||
return key == pattern
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
)
|
||||
|
||||
// setupBenchmarkRedis creates a miniredis instance for benchmarking
|
||||
func setupBenchmarkRedis(b *testing.B) string {
|
||||
b.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
return mr.Addr()
|
||||
}
|
||||
|
||||
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
|
||||
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform various operations
|
||||
_, _ = conn.Do("SET", "bench-key", "bench-value")
|
||||
_, _ = conn.Do("GET", "bench-key")
|
||||
_, _ = conn.Do("EXISTS", "bench-key")
|
||||
_, _ = conn.Do("DEL", "bench-key")
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
|
||||
func BenchmarkRedisBackend_SetGet(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("benchmark test data with some content")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Set operation
|
||||
err := backend.Set(ctx, "bench-key", testData, 0)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get operation
|
||||
_, _, _, err = backend.Get(ctx, "bench-key")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
|
||||
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("concurrent benchmark data")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = backend.Set(ctx, "concurrent-key", testData, 0)
|
||||
_, _, _, _ = backend.Get(ctx, "concurrent-key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
|
||||
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// This tests the pooling of RESPReader/RESPWriter
|
||||
_, _ = conn.Do("PING")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
|
||||
func BenchmarkConnectionPool_GetPut(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheShard represents a single shard of the sharded cache
|
||||
// Each shard has its own lock for reduced contention
|
||||
type cacheShard struct {
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
mu sync.RWMutex
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
size int64
|
||||
memoryUsed int64
|
||||
}
|
||||
|
||||
// newCacheShard creates a new cache shard
|
||||
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
|
||||
return &cacheShard{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a value from this shard
|
||||
// Returns: value, exists, expired
|
||||
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
return nil, true, true // exists but expired
|
||||
}
|
||||
|
||||
// Update access time and LRU position under write lock
|
||||
s.mu.Lock()
|
||||
// Re-check item exists (could have been deleted)
|
||||
item, exists = s.items[key]
|
||||
if exists && !item.isExpired() {
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.MoveToFront(elem)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
return item.value, true, false
|
||||
}
|
||||
|
||||
// set stores a value in this shard
|
||||
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if s.maxSize > 0 && s.size >= s.maxSize {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
|
||||
// Remove old item if exists
|
||||
if oldItem, exists := s.items[key]; exists {
|
||||
s.memoryUsed -= oldItem.size
|
||||
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
s.size--
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: size,
|
||||
}
|
||||
|
||||
item.element = s.lruList.PushFront(item)
|
||||
s.items[key] = item
|
||||
s.size++
|
||||
s.memoryUsed += size
|
||||
}
|
||||
|
||||
// delete removes a key from this shard
|
||||
// Returns true if the key was deleted
|
||||
func (s *cacheShard) delete(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
|
||||
// exists checks if a key exists (and is not expired)
|
||||
func (s *cacheShard) exists(key string) bool {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// ttl returns the remaining TTL for a key
|
||||
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, true // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return remaining, true
|
||||
}
|
||||
|
||||
// expire updates the TTL for an existing key
|
||||
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// keys returns all non-expired keys matching the pattern
|
||||
func (s *cacheShard) keys(pattern string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range s.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// clear removes all items from this shard
|
||||
func (s *cacheShard) clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.items = make(map[string]*memoryCacheItem)
|
||||
s.lruList.Init()
|
||||
s.size = 0
|
||||
s.memoryUsed = 0
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
// Returns the number of items removed
|
||||
func (s *cacheShard) cleanup() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var toRemove []*memoryCacheItem
|
||||
for _, item := range s.items {
|
||||
if item.isExpired() {
|
||||
toRemove = append(toRemove, item)
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range toRemove {
|
||||
s.deleteItemLocked(item)
|
||||
}
|
||||
|
||||
return len(toRemove)
|
||||
}
|
||||
|
||||
// stats returns statistics for this shard
|
||||
func (s *cacheShard) stats() (size, memory int64) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size, s.memoryUsed
|
||||
}
|
||||
|
||||
// deleteItemLocked removes an item (must be called with lock held)
|
||||
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
delete(s.items, item.key)
|
||||
s.size--
|
||||
s.memoryUsed -= item.size
|
||||
}
|
||||
|
||||
// evictLRULocked evicts the least recently used item (must be called with lock held)
|
||||
func (s *cacheShard) evictLRULocked() bool {
|
||||
if s.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// evictOne evicts one item from this shard (for global limit enforcement)
|
||||
func (s *cacheShard) evictOne() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.evictLRULocked()
|
||||
}
|
||||
|
||||
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
|
||||
// Returns zero time if shard is empty
|
||||
func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.lruList.Len() == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
return item.accessedAt
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a hash of a string
|
||||
// This is a fast, well-distributed hash function
|
||||
func fnv32(key string) uint32 {
|
||||
const (
|
||||
offset32 = uint32(2166136261)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
hash := offset32
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash ^= uint32(key[i])
|
||||
hash *= prime32
|
||||
}
|
||||
return hash
|
||||
}
|
||||
+283
@@ -0,0 +1,283 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
|
||||
func TestShardedCache_ShardDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a cache with large enough size to have multiple shards
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add many items to see distribution
|
||||
numItems := 1000
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("dist-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("dist-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Check that items are distributed across multiple shards
|
||||
shardStats := backend.MemoryCacheBackend.GetShardStats()
|
||||
nonEmptyShards := 0
|
||||
for _, stat := range shardStats {
|
||||
if stat["size"] > 0 {
|
||||
nonEmptyShards++
|
||||
}
|
||||
}
|
||||
|
||||
// With good hash distribution, we should have items in multiple shards
|
||||
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
|
||||
}
|
||||
|
||||
// TestShardedCache_ShardCount tests that shard count adapts to cache size
|
||||
func TestShardedCache_ShardCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
maxSize int
|
||||
expectLowShards bool
|
||||
}{
|
||||
{5, true}, // Very small cache should have fewer shards
|
||||
{100, true}, // Small cache should have fewer shards
|
||||
{10000, false}, // Large cache should have default shards
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = tt.maxSize
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
shardCount := backend.MemoryCacheBackend.GetShardCount()
|
||||
|
||||
if tt.expectLowShards {
|
||||
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
|
||||
} else {
|
||||
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
|
||||
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "concurrent-same-key"
|
||||
initialValue := []byte("initial-value")
|
||||
|
||||
err = backend.Set(ctx, key, initialValue, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of reads and writes
|
||||
if j%3 == 0 {
|
||||
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, newValue, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Key should still exist
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
|
||||
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a small cache to force eviction
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
// Small delay to ensure different access times
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Access some items to make them recently used
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Add more items to trigger eviction
|
||||
for i := 10; i < 15; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Recently accessed items (5-9) should still exist
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item %d should exist", i)
|
||||
}
|
||||
|
||||
// Check eviction stats
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
|
||||
func TestShardedCache_StatsAggregation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items to multiple shards
|
||||
numItems := 100
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("stats-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read some items
|
||||
for i := 0; i < numItems/2; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Read non-existent items
|
||||
for i := 0; i < 10; i++ {
|
||||
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
|
||||
// Verify stats
|
||||
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
|
||||
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
|
||||
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
|
||||
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
|
||||
|
||||
// Verify hit rate
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
|
||||
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_Parallel benchmarks parallel access
|
||||
func BenchmarkShardedCache_Parallel(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10000; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("bench-key-%d", i%10000)
|
||||
backend.Get(ctx, key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
|
||||
func BenchmarkShardedCache_MixedOps(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("mixed-key-%d", i%1000)
|
||||
if i%3 == 0 {
|
||||
value := []byte(fmt.Sprintf("mixed-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
} else {
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
+783
@@ -0,0 +1,783 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMemoryBackend_BasicOperations tests basic CRUD operations
|
||||
func TestMemoryBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
assert.LessOrEqual(t, remainingTTL, ttl)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
deleted, err := backend.Delete(ctx, "non-existent-delete")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
// Add multiple items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(0), size)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTLExpiration tests TTL and expiration
|
||||
func TestMemoryBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "short-ttl-key"
|
||||
value := []byte("short-ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLDecrement", func(t *testing.T) {
|
||||
key := "ttl-decrement-key"
|
||||
value := []byte("ttl-decrement-value")
|
||||
ttl := 2 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check TTL again - should be less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
|
||||
})
|
||||
|
||||
t.Run("CleanupExpiredItems", func(t *testing.T) {
|
||||
// Set multiple items with short TTL
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 50*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// All items should be cleaned up
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expired items should be cleaned up")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LRUEviction tests LRU eviction
|
||||
func TestMemoryBackend_LRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 5
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill cache to max size
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("lru-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("lru-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Access first item to make it most recently used
|
||||
_, _, exists, err := backend.Get(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Add a new item - should evict lru-key-1 (least recently used)
|
||||
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// lru-key-0 should still exist (was accessed recently)
|
||||
exists, err = backend.Exists(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item should not be evicted")
|
||||
|
||||
// lru-key-1 should be evicted
|
||||
exists, err = backend.Exists(ctx, "lru-key-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Least recently used item should be evicted")
|
||||
|
||||
// Check eviction count
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_MemoryLimit tests memory-based eviction
|
||||
func TestMemoryBackend_MemoryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100
|
||||
config.MaxMemoryBytes = 1024 // 1KB limit
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items until memory limit is reached
|
||||
largeValue := make([]byte, 512) // 512 bytes each
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("mem-key-%d", i)
|
||||
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
memory := stats["memory"].(int64)
|
||||
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
|
||||
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ConcurrentAccess tests thread safety
|
||||
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// Random deletes
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_UpdateExisting tests updating existing keys
|
||||
func TestMemoryBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
|
||||
|
||||
// Size should not increase (same key)
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Stats tests statistics tracking
|
||||
func TestMemoryBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add items and track hits/misses
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Hit
|
||||
backend.Get(ctx, "key1")
|
||||
// Miss
|
||||
backend.Get(ctx, "non-existent")
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_EmptyValues tests handling of empty values
|
||||
func TestMemoryBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LargeValues tests handling of large values
|
||||
func TestMemoryBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Close tests proper cleanup on close
|
||||
func TestMemoryBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
_, _, _, err = backend.Get(ctx, "close-key-0")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Closing again should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Ping tests ping operation
|
||||
func TestMemoryBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
|
||||
func TestMemoryBackend_ValueIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "isolation-key"
|
||||
originalValue := []byte("original-value")
|
||||
|
||||
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value and modify it
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Modify retrieved value
|
||||
if len(retrieved) > 0 {
|
||||
retrieved[0] = 'X'
|
||||
}
|
||||
|
||||
// Get again - should be unchanged
|
||||
retrieved2, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Keys tests the Keys method with pattern matching
|
||||
func TestMemoryBackend_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add test data
|
||||
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
|
||||
for _, key := range testKeys {
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("AllKeys", func(t *testing.T) {
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 5)
|
||||
})
|
||||
|
||||
t.Run("SpecificPattern", func(t *testing.T) {
|
||||
// Simple exact match
|
||||
keys, err := backend.Keys(ctx, "user:1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Contains(t, keys, "user:1")
|
||||
})
|
||||
|
||||
t.Run("ExcludesExpired", func(t *testing.T) {
|
||||
// Add an expired key
|
||||
expiredKey := "expired:key"
|
||||
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.Keys(ctx, "*")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Size tests the Size method
|
||||
func TestMemoryBackend_Size(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially empty
|
||||
size, err := backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), size)
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), size)
|
||||
|
||||
// Delete one
|
||||
backend.Delete(ctx, "key-0")
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), size)
|
||||
|
||||
// After close
|
||||
backend.Close()
|
||||
_, err = backend.Size(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTL tests the TTL method
|
||||
func TestMemoryBackend_TTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, []byte("value"), ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, remaining, 50*time.Second)
|
||||
assert.LessOrEqual(t, remaining, ttl)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
_, err := backend.TTL(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("NoExpiration", func(t *testing.T) {
|
||||
key := "no-expiry"
|
||||
// TTL of 0 typically means no expiration
|
||||
err := backend.Set(ctx, key, []byte("value"), 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
// No expiration returns 0
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.TTL(ctx, "key")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Expire tests the Expire method
|
||||
func TestMemoryBackend_Expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateTTL", func(t *testing.T) {
|
||||
key := "expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update to shorter TTL
|
||||
err = backend.Expire(ctx, key, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check new TTL
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, remaining, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("RemoveExpiration", func(t *testing.T) {
|
||||
key := "no-expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set TTL to 0 to remove expiration
|
||||
err = backend.Expire(ctx, key, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TTL should now be 0
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_IsHealthy tests the IsHealthy method
|
||||
func TestMemoryBackend_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be healthy when open
|
||||
assert.True(t, backend.IsHealthy())
|
||||
|
||||
// Should be unhealthy after close
|
||||
backend.Close()
|
||||
assert.False(t, backend.IsHealthy())
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Type tests the Type method
|
||||
func TestMemoryBackend_Type(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
backendType := backend.Type()
|
||||
assert.Equal(t, TypeMemory, backendType)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Capabilities tests the Capabilities method
|
||||
func TestMemoryBackend_Capabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
caps := backend.Capabilities()
|
||||
require.NotNil(t, caps)
|
||||
|
||||
// Memory backend should not be distributed or persistent
|
||||
assert.False(t, caps.Distributed)
|
||||
assert.False(t, caps.Persistent)
|
||||
|
||||
// Should support eviction and TTL
|
||||
assert.True(t, caps.Eviction)
|
||||
assert.True(t, caps.TTL)
|
||||
assert.True(t, caps.SupportsExpire)
|
||||
assert.True(t, caps.SupportsMultiGet)
|
||||
|
||||
// Check limits
|
||||
assert.Greater(t, caps.MaxKeySize, int64(0))
|
||||
assert.Greater(t, caps.MaxValueSize, int64(0))
|
||||
}
|
||||
|
||||
// TestMatchPattern tests the matchPattern helper function
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
pattern string
|
||||
key string
|
||||
matches bool
|
||||
}{
|
||||
{"*", "any-key", true},
|
||||
{"*", "another", true},
|
||||
{"user:1", "user:1", true},
|
||||
{"user:1", "user:2", false},
|
||||
{"*:suffix", "prefix:suffix", true},
|
||||
{"*suffix", "prefix-suffix", true},
|
||||
{"*abc", "xyzabc", true},
|
||||
{"*abc", "xyz", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.key)
|
||||
assert.Equal(t, tt.matches, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
+143
@@ -0,0 +1,143 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
|
||||
type MemoryBackend struct {
|
||||
*MemoryCacheBackend
|
||||
}
|
||||
|
||||
// NewMemoryBackend creates a new memory backend from a config
|
||||
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
|
||||
maxSize := int64(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1000
|
||||
}
|
||||
|
||||
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
|
||||
return &MemoryBackend{
|
||||
MemoryCacheBackend: cacheBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
|
||||
if err == ErrBackendUnavailable {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
val, err := m.MemoryCacheBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
if err == ErrCacheMiss {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err == ErrBackendUnavailable {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get TTL using the TTL method
|
||||
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
|
||||
if ttlErr != nil {
|
||||
// If we can't get TTL, still return the value with 0 TTL
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
var valueBytes []byte
|
||||
if val != nil {
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
|
||||
return valueBytes, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
// Check if key exists first
|
||||
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = m.MemoryCacheBackend.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return m.MemoryCacheBackend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
func (m *MemoryBackend) Clear(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BackendStats to map
|
||||
hitRate := float64(0)
|
||||
total := stats.Hits + stats.Misses
|
||||
if total > 0 {
|
||||
hitRate = float64(stats.Hits) / float64(total)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"shard_count": m.MemoryCacheBackend.GetShardCount(),
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
func (m *MemoryBackend) Close() error {
|
||||
return m.MemoryCacheBackend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
func (m *MemoryBackend) Ping(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Ping(ctx)
|
||||
}
|
||||
|
||||
// Ensure MemoryBackend implements CacheBackend
|
||||
var _ CacheBackend = (*MemoryBackend)(nil)
|
||||
Vendored
+566
@@ -0,0 +1,566 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pure-Go Redis client implementation
|
||||
// Compatible with Yaegi interpreter (no unsafe package)
|
||||
// Implements RESP protocol for basic Redis operations
|
||||
|
||||
var (
|
||||
ErrPoolExhausted = errors.New("connection pool exhausted")
|
||||
)
|
||||
|
||||
// RedisBackend implements a Redis-based cache backend using pure Go
|
||||
type RedisBackend struct {
|
||||
config *Config
|
||||
pool *ConnectionPool
|
||||
healthMonitor *HealthMonitor
|
||||
|
||||
// Metrics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
|
||||
func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.RedisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is required")
|
||||
}
|
||||
|
||||
// Create connection pool with health checks enabled
|
||||
// Timeouts are kept short to prevent request pileup when Redis is slow/stalled.
|
||||
// The UniversalCache uses 200ms context timeout, so socket timeouts should be
|
||||
// shorter to allow proper context cancellation handling.
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 500 * time.Millisecond,
|
||||
WriteTimeout: 500 * time.Millisecond,
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Create health monitor
|
||||
healthConfig := DefaultHealthMonitorConfig()
|
||||
healthMonitor := NewHealthMonitor(pool, healthConfig)
|
||||
|
||||
backend := &RedisBackend{
|
||||
config: config,
|
||||
pool: pool,
|
||||
healthMonitor: healthMonitor,
|
||||
}
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
// Start health monitoring
|
||||
healthMonitor.Start()
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Set stores a value in Redis with TTL
|
||||
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
// Execute with retry logic
|
||||
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
var err error
|
||||
|
||||
// Use PSETEX for millisecond precision, SETEX for second precision
|
||||
if ttl > 0 {
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs (millisecond precision)
|
||||
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs (second precision)
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
_, err = conn.Do("SET", prefixedKey, string(value))
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Get retrieves a value from Redis
|
||||
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
var resultValue []byte
|
||||
var resultTTL time.Duration
|
||||
var resultExists bool
|
||||
|
||||
// Execute with retry logic
|
||||
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
// Get value
|
||||
resp, err := conn.Do("GET", prefixedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
r.misses.Add(1)
|
||||
resultExists = false
|
||||
return nil // Not an error, key just doesn't exist
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get TTL
|
||||
ttlResp, err := conn.Do("TTL", prefixedKey)
|
||||
if err != nil {
|
||||
// If TTL fails, still return the value
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = 0
|
||||
resultExists = true
|
||||
return nil
|
||||
}
|
||||
|
||||
ttlSeconds, _ := RESPInt(ttlResp)
|
||||
var ttl time.Duration
|
||||
if ttlSeconds > 0 {
|
||||
ttl = time.Duration(ttlSeconds) * time.Second
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = ttl
|
||||
resultExists = true
|
||||
return nil
|
||||
})
|
||||
|
||||
return resultValue, resultTTL, resultExists, err
|
||||
}
|
||||
|
||||
// Delete removes a key from Redis
|
||||
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("DEL", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in Redis
|
||||
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("EXISTS", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys with the configured prefix
|
||||
func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
// Use FLUSHDB if no prefix (clear entire DB)
|
||||
if r.config.RedisPrefix == "" {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
}
|
||||
|
||||
// With prefix, we need to scan and delete keys
|
||||
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
|
||||
pattern := r.config.RedisPrefix + "*"
|
||||
resp, err := conn.Do("KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract keys from array response
|
||||
keys, ok := resp.([]interface{})
|
||||
if !ok || len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete each key
|
||||
for _, keyInterface := range keys {
|
||||
key, err := RESPString(keyInterface)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns backend statistics
|
||||
func (r *RedisBackend) GetStats() map[string]interface{} {
|
||||
hits := r.hits.Load()
|
||||
misses := r.misses.Load()
|
||||
total := hits + misses
|
||||
|
||||
hitRate := float64(0)
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"backend": "redis-pure-go",
|
||||
"address": r.config.RedisAddr,
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": hitRate,
|
||||
"pool": r.pool.Stats(),
|
||||
}
|
||||
|
||||
// Add health monitor stats if available
|
||||
if r.healthMonitor != nil {
|
||||
stats["health"] = r.healthMonitor.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks Redis connectivity
|
||||
func (r *RedisBackend) Ping(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
_, err = conn.Do("PING")
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the Redis backend and all connections
|
||||
func (r *RedisBackend) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Stop health monitor
|
||||
if r.healthMonitor != nil {
|
||||
r.healthMonitor.Stop()
|
||||
}
|
||||
|
||||
// Close connection pool
|
||||
if r.pool != nil {
|
||||
return r.pool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds the configured prefix to a key
|
||||
func (r *RedisBackend) prefixKey(key string) string {
|
||||
if r.config.RedisPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
|
||||
// It checks context cancellation at multiple points to ensure fast abort when the
|
||||
// caller's context is cancelled (e.g., due to request timeout).
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// Check context before each attempt to fail fast
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
// If we can't get a connection and this is the last attempt, fail
|
||||
if attempt == maxRetries-1 {
|
||||
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the operation
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// Check context after operation - if cancelled, don't bother retrying
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// If successful, return
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If error is not retryable or last attempt, fail
|
||||
if attempt == maxRetries-1 || !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts", maxRetries)
|
||||
}
|
||||
|
||||
// isRetryableError determines if an error is worth retrying
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retry on connection errors, timeouts, etc.
|
||||
// Don't retry on application-level errors like wrong type
|
||||
errMsg := err.Error()
|
||||
retryablePatterns := []string{
|
||||
"connection",
|
||||
"timeout",
|
||||
"EOF",
|
||||
"broken pipe",
|
||||
"reset by peer",
|
||||
}
|
||||
|
||||
for _, pattern := range retryablePatterns {
|
||||
if contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For single items, use regular Set
|
||||
if len(items) == 1 {
|
||||
for key, value := range items {
|
||||
return r.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all SET commands
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
if ttl > 0 {
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs
|
||||
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs
|
||||
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
pipeline.Queue("SET", prefixedKey, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline SetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Check responses for errors (each should be "OK")
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
if str, ok := resp.(string); ok && str == "OK" {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
// For single key, use regular Get
|
||||
if len(keys) == 1 {
|
||||
result := make(map[string][]byte)
|
||||
value, _, exists, err := r.Get(ctx, keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[keys[0]] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all GET commands
|
||||
prefixedKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
prefixedKeys[i] = r.prefixKey(key)
|
||||
pipeline.Queue("GET", prefixedKeys[i])
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Process responses
|
||||
result := make(map[string][]byte)
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
// Key doesn't exist
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
// Invalid response, skip this key
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
result[keys[i]] = []byte(value)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
+170
@@ -0,0 +1,170 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
lastCheckTime atomic.Int64
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
OnHealthChange func(healthy bool)
|
||||
CheckInterval time.Duration
|
||||
Timeout time.Duration
|
||||
UnhealthyThreshold int
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
|
||||
return &HealthMonitorConfig{
|
||||
CheckInterval: 5 * time.Second,
|
||||
Timeout: 3 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor
|
||||
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
|
||||
if config == nil {
|
||||
config = DefaultHealthMonitorConfig()
|
||||
}
|
||||
|
||||
hm := &HealthMonitor{
|
||||
pool: pool,
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
hm.healthy.Store(true) // Assume healthy initially
|
||||
return hm
|
||||
}
|
||||
|
||||
// Start begins health monitoring
|
||||
func (hm *HealthMonitor) Start() {
|
||||
if hm.running.Swap(true) {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
hm.wg.Add(1)
|
||||
go hm.monitorLoop()
|
||||
}
|
||||
|
||||
// Stop stops health monitoring
|
||||
func (hm *HealthMonitor) Stop() {
|
||||
if !hm.running.Swap(false) {
|
||||
return // Not running
|
||||
}
|
||||
|
||||
close(hm.stopChan)
|
||||
hm.wg.Wait()
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (hm *HealthMonitor) IsHealthy() bool {
|
||||
return hm.healthy.Load()
|
||||
}
|
||||
|
||||
// GetStats returns health monitor statistics
|
||||
func (hm *HealthMonitor) GetStats() map[string]interface{} {
|
||||
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
|
||||
|
||||
return map[string]interface{}{
|
||||
"healthy": hm.healthy.Load(),
|
||||
"consecutive_failures": hm.consecutiveFailures.Load(),
|
||||
"total_checks": hm.totalChecks.Load(),
|
||||
"total_failures": hm.totalFailures.Load(),
|
||||
"last_check": lastCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop runs the health check loop
|
||||
func (hm *HealthMonitor) monitorLoop() {
|
||||
defer hm.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(hm.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Perform initial check immediately
|
||||
hm.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hm.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck executes a health check
|
||||
func (hm *HealthMonitor) performHealthCheck() {
|
||||
hm.totalChecks.Add(1)
|
||||
hm.lastCheckTime.Store(time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to get a connection and ping Redis
|
||||
conn, err := hm.pool.Get(ctx)
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
defer hm.pool.Put(conn)
|
||||
|
||||
// Ping Redis
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Success!
|
||||
hm.recordSuccess()
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hm *HealthMonitor) recordSuccess() {
|
||||
wasHealthy := hm.healthy.Load()
|
||||
hm.consecutiveFailures.Store(0)
|
||||
hm.healthy.Store(true)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if !wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(true)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hm *HealthMonitor) recordFailure() {
|
||||
hm.totalFailures.Add(1)
|
||||
failures := hm.consecutiveFailures.Add(1)
|
||||
|
||||
wasHealthy := hm.healthy.Load()
|
||||
|
||||
// Mark unhealthy if threshold exceeded
|
||||
if failures >= int64(hm.config.UnhealthyThreshold) {
|
||||
hm.healthy.Store(false)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHealthMonitor_BasicOperation tests basic health monitoring
|
||||
func TestHealthMonitor_BasicOperation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create health monitor with fast check interval for testing
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
require.NotNil(t, hm)
|
||||
|
||||
// Initially should be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for a few checks
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Should still be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
require.NotNil(t, stats)
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Greater(t, stats["total_checks"].(int64), int64(0))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
|
||||
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var healthChangedCalled atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if !healthy {
|
||||
healthChangedCalled.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
|
||||
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
assert.False(t, stats["healthy"].(bool))
|
||||
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
|
||||
assert.Greater(t, stats["total_failures"].(int64), int64(0))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
|
||||
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var recoveryDetected atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if healthy {
|
||||
recoveryDetected.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Should detect server failure")
|
||||
|
||||
// Clear error to simulate recovery
|
||||
mr.ClearError()
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should be healthy again
|
||||
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
|
||||
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
|
||||
|
||||
// Consecutive failures should be reset
|
||||
stats := hm.GetStats()
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StartStop tests start/stop behavior
|
||||
func TestHealthMonitor_StartStop(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Starting again should be no-op
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Stop monitoring
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
|
||||
// Stopping again should be no-op
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
|
||||
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create multiple monitors
|
||||
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 150 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
})
|
||||
|
||||
// Start both
|
||||
hm1.Start()
|
||||
hm2.Start()
|
||||
|
||||
// Both should be healthy
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
assert.True(t, hm1.IsHealthy())
|
||||
assert.True(t, hm2.IsHealthy())
|
||||
|
||||
// Stop both
|
||||
hm1.Stop()
|
||||
hm2.Stop()
|
||||
|
||||
// Verify they stopped
|
||||
assert.False(t, hm1.running.Load())
|
||||
assert.False(t, hm2.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StatsAccuracy tests stats tracking
|
||||
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for some checks
|
||||
time.Sleep(550 * time.Millisecond)
|
||||
|
||||
stats := hm.GetStats()
|
||||
|
||||
// Should have performed multiple checks
|
||||
totalChecks := stats["total_checks"].(int64)
|
||||
assert.GreaterOrEqual(t, totalChecks, int64(4))
|
||||
|
||||
// All checks should succeed
|
||||
assert.Equal(t, int64(0), stats["total_failures"].(int64))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
|
||||
// Last check time should be recent (within check interval + buffer)
|
||||
// Use 2s tolerance to account for CI runner load and timing variance
|
||||
lastCheck := stats["last_check"].(time.Time)
|
||||
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_DefaultConfig tests default configuration
|
||||
func TestHealthMonitor_DefaultConfig(t *testing.T) {
|
||||
config := DefaultHealthMonitorConfig()
|
||||
|
||||
assert.Equal(t, 5*time.Second, config.CheckInterval)
|
||||
assert.Equal(t, 3*time.Second, config.Timeout)
|
||||
assert.Equal(t, 3, config.UnhealthyThreshold)
|
||||
assert.Nil(t, config.OnHealthChange)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
|
||||
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1, // Very small pool
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 50 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Get the only connection, blocking health checks
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for health check attempts
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Health monitor might mark as unhealthy due to timeouts
|
||||
stats := hm.GetStats()
|
||||
t.Logf("Stats with blocked pool: %+v", stats)
|
||||
|
||||
// Return connection
|
||||
pool.Put(conn)
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Should recover
|
||||
assert.True(t, hm.IsHealthy())
|
||||
}
|
||||
|
||||
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
|
||||
func TestConnectionPool_WithHealthChecks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn))
|
||||
|
||||
// Use connection
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse and validate
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
|
||||
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 3,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get and return a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
|
||||
initialTotal := pool.totalConns.Load()
|
||||
|
||||
// Close the connection manually to make it stale
|
||||
conn.Close()
|
||||
|
||||
// Get another connection - should detect stale and create new
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn2))
|
||||
|
||||
pool.Put(conn2)
|
||||
|
||||
// Total connections might be same or less (stale removed)
|
||||
finalTotal := pool.totalConns.Load()
|
||||
assert.LessOrEqual(t, finalTotal, initialTotal+1)
|
||||
}
|
||||
+461
@@ -0,0 +1,461 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a miniredis instance for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "test:",
|
||||
PoolSize: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
backend.Close()
|
||||
})
|
||||
|
||||
return mr, backend
|
||||
}
|
||||
|
||||
// TestPipeline_Basic tests basic pipeline functionality
|
||||
func TestPipeline_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.Addr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SingleCommand", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "single-key", "single-value")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
})
|
||||
|
||||
t.Run("MultipleCommands", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "key1", "value1")
|
||||
pipeline.Queue("SET", "key2", "value2")
|
||||
pipeline.Queue("SET", "key3", "value3")
|
||||
pipeline.Queue("GET", "key1")
|
||||
pipeline.Queue("GET", "key2")
|
||||
pipeline.Queue("GET", "key3")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 6)
|
||||
|
||||
// First 3 are SET responses
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
assert.Equal(t, "OK", responses[1])
|
||||
assert.Equal(t, "OK", responses[2])
|
||||
|
||||
// Last 3 are GET responses
|
||||
assert.Equal(t, "value1", responses[3])
|
||||
assert.Equal(t, "value2", responses[4])
|
||||
assert.Equal(t, "value3", responses[5])
|
||||
})
|
||||
|
||||
t.Run("EmptyPipeline", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, responses)
|
||||
})
|
||||
|
||||
t.Run("NilResponses", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("GET", "nonexistent-key")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Nil(t, responses[0])
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_SetMany tests pipelined SetMany
|
||||
func TestPipeline_SetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetManyItems", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetManyEmpty", func(t *testing.T) {
|
||||
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetManySingleItem", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"single-setmany": []byte("single-value"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
value, _, exists, err := backend.Get(ctx, "single-setmany")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("single-value"), value)
|
||||
})
|
||||
|
||||
t.Run("SetManyNoTTL", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"nottl-key1": []byte("value1"),
|
||||
"nottl-key2": []byte("value2"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Keys should exist
|
||||
for key := range items {
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_GetMany tests pipelined GetMany
|
||||
func TestPipeline_GetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("getmany-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetManyExisting", func(t *testing.T) {
|
||||
keys := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
keys[i] = fmt.Sprintf("getmany-key-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 10)
|
||||
|
||||
for i, key := range keys {
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyMixed", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"getmany-key-0", // exists
|
||||
"nonexistent-key-1", // doesn't exist
|
||||
"getmany-key-2", // exists
|
||||
"nonexistent-key-2", // doesn't exist
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
|
||||
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
|
||||
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
|
||||
assert.NotContains(t, results, "nonexistent-key-1")
|
||||
assert.NotContains(t, results, "nonexistent-key-2")
|
||||
})
|
||||
|
||||
t.Run("GetManyEmpty", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
|
||||
t.Run("GetManySingleKey", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
|
||||
})
|
||||
|
||||
t.Run("GetManyAllNonexistent", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"nonexistent-1",
|
||||
"nonexistent-2",
|
||||
"nonexistent-3",
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_LargeBatch tests pipelining with large batches
|
||||
func TestPipeline_LargeBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany100Items", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify random samples
|
||||
for _, i := range []int{0, 25, 50, 75, 99} {
|
||||
key := fmt.Sprintf("large-batch-%d", i)
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany100Items", func(t *testing.T) {
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("large-batch-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 100)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
|
||||
func TestPipeline_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some items
|
||||
items := map[string][]byte{
|
||||
"stats-key-1": []byte("value1"),
|
||||
"stats-key-2": []byte("value2"),
|
||||
}
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get items (some exist, some don't)
|
||||
keys := []string{
|
||||
"stats-key-1",
|
||||
"stats-key-2",
|
||||
"stats-key-nonexistent",
|
||||
}
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Check stats
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
|
||||
assert.Equal(t, int64(2), hits, "Should have 2 hits")
|
||||
assert.Equal(t, int64(1), misses, "Should have 1 miss")
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
|
||||
func BenchmarkPipeline_SetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
|
||||
func BenchmarkPipeline_GetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
// Prepare keys
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("bench-key-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
|
||||
func BenchmarkPipeline_VsSequential(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
keys := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("compare-key-%d", i)
|
||||
keys[i] = key
|
||||
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
|
||||
}
|
||||
|
||||
b.Run("Pipelined-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for key, value := range items {
|
||||
_ = backend.Set(ctx, key, value, time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for get benchmarks
|
||||
_ = backend.SetMany(ctx, items, time.Hour)
|
||||
|
||||
b.Run("Pipelined-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, key := range keys {
|
||||
_, _, _, _ = backend.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
+455
@@ -0,0 +1,455 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionPool manages a pool of Redis connections
|
||||
// Pure-Go implementation compatible with Yaegi
|
||||
type ConnectionPool struct {
|
||||
config *PoolConfig
|
||||
|
||||
connections chan *RedisConn
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Metrics
|
||||
activeConns atomic.Int32
|
||||
totalConns atomic.Int32
|
||||
gets atomic.Int64
|
||||
puts atomic.Int64
|
||||
timeouts atomic.Int64
|
||||
}
|
||||
|
||||
// PoolConfig holds connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if config.MaxConnections <= 0 {
|
||||
config.MaxConnections = 10
|
||||
}
|
||||
|
||||
if config.ConnectTimeout == 0 {
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
pool := &ConnectionPool{
|
||||
config: config,
|
||||
connections: make(chan *RedisConn, config.MaxConnections),
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool or creates a new one
|
||||
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.gets.Add(1)
|
||||
|
||||
// Try to get a connection with validation
|
||||
maxAttempts := 3
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
var conn *RedisConn
|
||||
var err error
|
||||
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
// Wait before retry with exponential backoff
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
p.totalConns.Add(1)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Pool exhausted, wait for a connection with timeout
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
p.timeouts.Add(1)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(p.config.ConnectTimeout):
|
||||
p.timeouts.Add(1)
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to get healthy connection after retries")
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.puts.Add(1)
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
|
||||
// Return to pool (non-blocking)
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *ConnectionPool) Close() error {
|
||||
if p.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
close(p.connections)
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns pool statistics
|
||||
func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"active_connections": p.activeConns.Load(),
|
||||
"total_connections": p.totalConns.Load(),
|
||||
"max_connections": p.config.MaxConnections,
|
||||
"gets": p.gets.Load(),
|
||||
"puts": p.puts.Load(),
|
||||
"timeouts": p.timeouts.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
redisConn := &RedisConn{
|
||||
conn: conn,
|
||||
readTimeout: p.config.ReadTimeout,
|
||||
writeTimeout: p.config.WriteTimeout,
|
||||
}
|
||||
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return redisConn, nil
|
||||
}
|
||||
|
||||
// RedisConn represents a single Redis connection
|
||||
type RedisConn struct {
|
||||
conn net.Conn
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Do executes a Redis command and returns the response
|
||||
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
if c.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
// Protect against possible overflow
|
||||
if len(s) > maxTotalArgBytes-totalBytes {
|
||||
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
|
||||
}
|
||||
totalBytes += len(s)
|
||||
if totalBytes > maxTotalArgBytes {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
writer := NewRESPWriter(c.conn)
|
||||
err := writer.WriteCommand(cmdArgs...)
|
||||
writer.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
c.closed.Store(true)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
reader := NewRESPReader(c.conn)
|
||||
resp, err := reader.ReadResponse()
|
||||
reader.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNilResponse) {
|
||||
c.closed.Store(true)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *RedisConn) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionHealthy validates a connection is still working
|
||||
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
if conn == nil || conn.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Pipeline represents a Redis pipeline for batch operations
|
||||
// It queues multiple commands and executes them in a single round-trip
|
||||
type Pipeline struct {
|
||||
conn *RedisConn
|
||||
commands []pipelineCommand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// pipelineCommand represents a single command in the pipeline
|
||||
type pipelineCommand struct {
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
// NewPipeline creates a new pipeline for the connection
|
||||
func (c *RedisConn) NewPipeline() *Pipeline {
|
||||
return &Pipeline{
|
||||
conn: c,
|
||||
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
||||
}
|
||||
}
|
||||
|
||||
// Queue adds a command to the pipeline
|
||||
func (p *Pipeline) Queue(command string, args ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.commands = append(p.commands, pipelineCommand{
|
||||
command: command,
|
||||
args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute sends all queued commands and returns all responses
|
||||
// Returns a slice of responses in the same order as commands were queued
|
||||
func (p *Pipeline) Execute() ([]interface{}, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.conn.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.conn.mu.Lock()
|
||||
defer p.conn.mu.Unlock()
|
||||
|
||||
// Set write timeout for all commands
|
||||
if p.conn.writeTimeout > 0 {
|
||||
// Use longer timeout for batch operations
|
||||
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second // Cap at 30 seconds
|
||||
}
|
||||
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Write all commands (pipelining - send all before reading any responses)
|
||||
writer := NewRESPWriter(p.conn.conn)
|
||||
for _, cmd := range p.commands {
|
||||
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
||||
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
||||
writer.Release()
|
||||
p.conn.closed.Store(true)
|
||||
return nil, fmt.Errorf("pipeline write error: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Release()
|
||||
|
||||
// Set read timeout for all responses
|
||||
if p.conn.readTimeout > 0 {
|
||||
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Read all responses
|
||||
responses := make([]interface{}, len(p.commands))
|
||||
reader := NewRESPReader(p.conn.conn)
|
||||
defer reader.Release()
|
||||
|
||||
for i := range p.commands {
|
||||
resp, err := reader.ReadResponse()
|
||||
if err != nil {
|
||||
// For nil responses, store nil instead of erroring
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
responses[i] = nil
|
||||
continue
|
||||
}
|
||||
p.conn.closed.Store(true)
|
||||
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
||||
}
|
||||
responses[i] = resp
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// Clear resets the pipeline for reuse
|
||||
func (p *Pipeline) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.commands = p.commands[:0]
|
||||
}
|
||||
|
||||
// Len returns the number of queued commands
|
||||
func (p *Pipeline) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.commands)
|
||||
}
|
||||
+620
@@ -0,0 +1,620 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConnectionPool_BasicOperations tests basic pool operations
|
||||
func TestConnectionPool_BasicOperations(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
t.Run("GetAndPutConnection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Verify connection works
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse same connection
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
stats := pool.Stats()
|
||||
require.NotNil(t, stats)
|
||||
|
||||
assert.Contains(t, stats, "active_connections")
|
||||
assert.Contains(t, stats, "total_connections")
|
||||
assert.Contains(t, stats, "max_connections")
|
||||
assert.Equal(t, 5, stats["max_connections"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_MaxConnections tests pool size limits
|
||||
func TestConnectionPool_MaxConnections(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
maxConns := 3
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: maxConns,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get max connections
|
||||
conns := make([]*RedisConn, maxConns)
|
||||
for i := 0; i < maxConns; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
assert.Equal(t, int32(maxConns), stats["total_connections"])
|
||||
assert.Equal(t, int32(maxConns), stats["active_connections"])
|
||||
|
||||
// Try to get one more - should block/timeout
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Get(ctx2)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn)
|
||||
|
||||
// Return one connection
|
||||
pool.Put(conns[0])
|
||||
|
||||
// Now we should be able to get a connection
|
||||
conn, err = pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
for i := 1; i < maxConns; i++ {
|
||||
pool.Put(conns[i])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
|
||||
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
numGoroutines := 50
|
||||
numOperations := 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Spawn goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Do some work
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Small delay
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
t.Logf("Final stats: %+v", stats)
|
||||
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
|
||||
assert.Equal(t, int32(0), stats["active_connections"])
|
||||
}
|
||||
|
||||
// TestConnectionPool_ContextCancellation tests context cancellation
|
||||
func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get the only connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn2)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Authentication tests auth support
|
||||
func TestConnectionPool_Authentication(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
// Set password on miniredis
|
||||
mr.server.RequireAuth("secret-password")
|
||||
|
||||
t.Run("CorrectPassword", func(t *testing.T) {
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "secret-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
})
|
||||
|
||||
t.Run("WrongPassword", func(t *testing.T) {
|
||||
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "wrong-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
_, err := NewConnectionPool(config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_DatabaseSelection tests DB selection
|
||||
func TestConnectionPool_DatabaseSelection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
DB: 5,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connection should be on DB 5
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_ClosedConnection tests handling closed connections
|
||||
func TestConnectionPool_ClosedConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close it manually
|
||||
conn.Close()
|
||||
|
||||
// Try to use it
|
||||
_, err = conn.Do("PING")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Return to pool (should be discarded)
|
||||
pool.Put(conn)
|
||||
|
||||
// Get new connection - should create a new one
|
||||
conn2, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
resp, err := conn2.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Close tests pool closure
|
||||
func TestConnectionPool_Close(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get some connections
|
||||
conns := make([]*RedisConn, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Return them
|
||||
for _, conn := range conns {
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Close pool
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get connection from closed pool
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Close again should be no-op
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Timeouts tests various timeout scenarios
|
||||
func TestConnectionPool_Timeouts(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
WriteTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Normal operation should work
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestRedisConn_DoCommand tests the Do method
|
||||
func TestRedisConn_DoCommand(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SET and GET", func(t *testing.T) {
|
||||
// SET
|
||||
resp, err := conn.Do("SET", "testkey", "testvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// GET
|
||||
resp, err = conn.Do("GET", "testkey")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testvalue", resp)
|
||||
})
|
||||
|
||||
t.Run("DEL", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "delkey", "delvalue")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DEL
|
||||
resp, err := conn.Do("DEL", "delkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
|
||||
t.Run("EXISTS", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "existskey", "value")
|
||||
require.NoError(t, err)
|
||||
|
||||
// EXISTS - key exists
|
||||
resp, err := conn.Do("EXISTS", "existskey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// EXISTS - key doesn't exist
|
||||
resp, err = conn.Do("EXISTS", "nonexistent")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("TTL commands", func(t *testing.T) {
|
||||
// SETEX
|
||||
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// TTL
|
||||
resp, err = conn.Do("TTL", "ttlkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
ttl, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, ttl, int64(0))
|
||||
assert.LessOrEqual(t, ttl, int64(60))
|
||||
})
|
||||
}
|
||||
|
||||
// TestPoolConfig_Defaults tests default configuration values
|
||||
func TestPoolConfig_Defaults(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
// Leave other fields at zero values
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Should use defaults
|
||||
assert.Equal(t, 10, pool.config.MaxConnections)
|
||||
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
|
||||
|
||||
// Verify it works
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_NilConnection tests handling nil connections
|
||||
func TestConnectionPool_NilConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Putting nil should be safe
|
||||
pool.Put(nil)
|
||||
|
||||
// Pool should still work
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StatsTracking tests metrics tracking
|
||||
func TestConnectionPool_StatsTracking(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := pool.Stats()
|
||||
initialGets := stats["gets"].(int64)
|
||||
initialPuts := stats["puts"].(int64)
|
||||
|
||||
// Perform operations
|
||||
numOps := 10
|
||||
for i := 0; i < numOps; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Check updated stats
|
||||
stats = pool.Stats()
|
||||
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
|
||||
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
|
||||
assert.Equal(t, int32(0), stats["active_connections"].(int32))
|
||||
}
|
||||
|
||||
// TestRedisConn_TooManyArguments tests protection against allocation overflow
|
||||
func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("AcceptableArgumentCount", func(t *testing.T) {
|
||||
// Should work with reasonable number of args
|
||||
args := make([]string, 100)
|
||||
for i := range args {
|
||||
args[i] = "value"
|
||||
}
|
||||
_, err := conn.Do("MSET", args...)
|
||||
// May fail due to Redis constraints, but shouldn't panic or error on overflow
|
||||
// Just verify it doesn't trigger our overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RejectExcessiveArguments", func(t *testing.T) {
|
||||
// Create an absurdly large number of arguments that would cause overflow
|
||||
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
|
||||
args := make([]string, 1<<20) // 1,048,576 args
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("MSET", args...)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many arguments")
|
||||
})
|
||||
|
||||
t.Run("BoundaryCase", func(t *testing.T) {
|
||||
// Test exactly at the boundary (maxSafeArgs)
|
||||
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("ECHO", args...)
|
||||
// Should not error due to overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
}
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisBackend_BasicOperations tests basic Redis operations
|
||||
func TestRedisBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "redis-test-key"
|
||||
value := []byte("redis-test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "redis-delete-key"
|
||||
value := []byte("redis-delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "redis-exists-key"
|
||||
value := []byte("redis-exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
|
||||
func TestRedisBackend_KeyPrefixing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "test:prefix:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "my-key"
|
||||
value := []byte("my-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that key is stored with prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, "test:prefix:my-key", keys[0])
|
||||
|
||||
// Get should work without prefix
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// TestRedisBackend_TTLExpiration tests TTL handling
|
||||
func TestRedisBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward time in miniredis
|
||||
mr.FastForward(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLRemaining", func(t *testing.T) {
|
||||
key := "ttl-remaining-key"
|
||||
value := []byte("ttl-remaining-value")
|
||||
ttl := 10 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL is less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_Clear tests clearing all keys
|
||||
func TestRedisBackend_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "clear-test:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add multiple keys
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify keys exist
|
||||
keys := mr.CheckKeys()
|
||||
assert.Len(t, keys, 10)
|
||||
|
||||
// Clear all
|
||||
err = backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
keys = mr.CheckKeys()
|
||||
assert.Len(t, keys, 0)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
|
||||
func TestRedisBackend_ConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Try to connect to non-existent Redis
|
||||
config := DefaultRedisConfig("localhost:9999")
|
||||
_, err := NewRedisBackend(config)
|
||||
assert.Error(t, err, "Should fail to connect to non-existent Redis")
|
||||
}
|
||||
|
||||
// TestRedisBackend_RedisErrors tests handling of Redis errors
|
||||
func TestRedisBackend_RedisErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated error")
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Clear error
|
||||
mr.ClearError()
|
||||
|
||||
// Operations should work again
|
||||
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConcurrentAccess tests thread safety
|
||||
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0))
|
||||
}
|
||||
|
||||
// TestRedisBackend_Stats tests statistics tracking
|
||||
func TestRedisBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add and access items
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Get(ctx, "key1") // Hit
|
||||
backend.Get(ctx, "non-existent") // Miss
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Ping tests health check
|
||||
func TestRedisBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Close tests proper cleanup
|
||||
func TestRedisBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Double close should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_UpdateExisting tests updating existing keys
|
||||
func TestRedisBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute)
|
||||
}
|
||||
|
||||
// TestRedisBackend_LargeValues tests handling of large values
|
||||
func TestRedisBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_EmptyValues tests handling of empty values
|
||||
func TestRedisBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_PipelineOperations tests batch operations
|
||||
func TestRedisBackend_PipelineOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("batch-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("batch-value-%d", i))
|
||||
items[key] = value
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrieved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany", func(t *testing.T) {
|
||||
// Set test data
|
||||
testData := GenerateTestData(5)
|
||||
for key, value := range testData {
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Get all keys
|
||||
keys := make([]string, 0, len(testData))
|
||||
for key := range testData {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, len(testData))
|
||||
|
||||
for key, expectedValue := range testData {
|
||||
retrievedValue, exists := results[key]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyWithNonExistent", func(t *testing.T) {
|
||||
keys := []string{"exists-1", "non-existent", "exists-2"}
|
||||
|
||||
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
|
||||
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
assert.Equal(t, []byte("value-1"), results["exists-1"])
|
||||
assert.Equal(t, []byte("value-2"), results["exists-2"])
|
||||
_, exists := results["non-existent"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_NoPrefix tests operation without prefix
|
||||
func TestRedisBackend_NoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "" // No prefix
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "no-prefix-key"
|
||||
value := []byte("no-prefix-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check key is stored without prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, key, keys[0])
|
||||
}
|
||||
Vendored
+251
@@ -0,0 +1,251 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
|
||||
func (w *RESPWriter) WriteCommand(args ...string) error {
|
||||
// Write array header
|
||||
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write each argument as bulk string
|
||||
for _, arg := range args {
|
||||
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RESPReader reads RESP protocol messages
|
||||
type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
func (r *RESPReader) ReadResponse() (interface{}, error) {
|
||||
typeByte, err := r.r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typeByte {
|
||||
case '+': // Simple string
|
||||
return r.readSimpleString()
|
||||
case '-': // Error
|
||||
return nil, r.readError()
|
||||
case ':': // Integer
|
||||
return r.readInteger()
|
||||
case '$': // Bulk string
|
||||
return r.readBulkString()
|
||||
case '*': // Array
|
||||
return r.readArray()
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
|
||||
}
|
||||
}
|
||||
|
||||
// readSimpleString reads a simple string (+OK\r\n)
|
||||
func (r *RESPReader) readSimpleString() (string, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// readError reads an error message (-Error message\r\n)
|
||||
func (r *RESPReader) readError() error {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(line)
|
||||
}
|
||||
|
||||
// readInteger reads an integer (:1000\r\n)
|
||||
func (r *RESPReader) readInteger() (int64, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseInt(line, 10, 64)
|
||||
}
|
||||
|
||||
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
|
||||
func (r *RESPReader) readBulkString() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil bulk string
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read exactly 'length' bytes plus \r\n
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r.r, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify \r\n terminator
|
||||
if buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
|
||||
func (r *RESPReader) readArray() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil array
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read each element
|
||||
result := make([]interface{}, length)
|
||||
for i := 0; i < length; i++ {
|
||||
elem, err := r.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = elem
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readLine reads a line terminated by \r\n
|
||||
func (r *RESPReader) readLine() (string, error) {
|
||||
line, err := r.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Remove \r\n
|
||||
line = strings.TrimSuffix(line, "\r\n")
|
||||
if !strings.HasSuffix(line+"\r\n", "\r\n") {
|
||||
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// RESPString extracts a string from RESP response
|
||||
func RESPString(resp interface{}) (string, error) {
|
||||
if resp == nil {
|
||||
return "", ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("expected string, got %T", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// RESPInt extracts an integer from RESP response
|
||||
func RESPInt(resp interface{}) (int64, error) {
|
||||
if resp == nil {
|
||||
return 0, ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("expected integer, got %T", resp)
|
||||
}
|
||||
}
|
||||
Vendored
+495
@@ -0,0 +1,495 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRESPWriter_WriteCommand tests RESP command writing
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected string
|
||||
args []string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
args: []string{"PING"},
|
||||
expected: "*1\r\n$4\r\nPING\r\n",
|
||||
},
|
||||
{
|
||||
name: "SET command",
|
||||
args: []string{"SET", "key", "value"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "SETEX command",
|
||||
args: []string{"SETEX", "mykey", "60", "myvalue"},
|
||||
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "DEL with multiple keys",
|
||||
args: []string{"DEL", "key1", "key2", "key3"},
|
||||
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with empty string",
|
||||
args: []string{"SET", "key", ""},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with special characters",
|
||||
args: []string{"SET", "key", "val\r\nue"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(buf)
|
||||
|
||||
err := writer.WriteCommand(tt.args...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadSimpleString tests reading simple strings
|
||||
func TestRESPReader_ReadSimpleString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "OK response",
|
||||
input: "+OK\r\n",
|
||||
expected: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "PONG response",
|
||||
input: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "+\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "+Hello World\r\n",
|
||||
expected: "Hello World",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadError tests reading error messages
|
||||
func TestRESPReader_ReadError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ERR error",
|
||||
input: "-ERR unknown command\r\n",
|
||||
expectedError: "ERR unknown command",
|
||||
},
|
||||
{
|
||||
name: "WRONGTYPE error",
|
||||
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
|
||||
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
|
||||
},
|
||||
{
|
||||
name: "Simple error",
|
||||
input: "-Error\r\n",
|
||||
expectedError: "Error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadInteger tests reading integers
|
||||
func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Zero",
|
||||
input: ":0\r\n",
|
||||
expected: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Positive integer",
|
||||
input: ":1000\r\n",
|
||||
expected: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Negative integer",
|
||||
input: ":-1\r\n",
|
||||
expected: -1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Large integer",
|
||||
input: ":9223372036854775807\r\n",
|
||||
expected: 9223372036854775807,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid integer",
|
||||
input: ":abc\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected interface{}
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Simple bulk string",
|
||||
input: "$6\r\nfoobar\r\n",
|
||||
expected: "foobar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty bulk string",
|
||||
input: "$0\r\n\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil bulk string",
|
||||
input: "$-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Binary safe bulk string",
|
||||
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
|
||||
expected: "\x00\x01\x02\x03\x04",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid length",
|
||||
input: "$abc\r\ntest\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadArray tests reading arrays
|
||||
func TestRESPReader_ReadArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Empty array",
|
||||
input: "*0\r\n",
|
||||
expected: []interface{}{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of bulk strings",
|
||||
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
|
||||
expected: []interface{}{
|
||||
"foo",
|
||||
"bar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of integers",
|
||||
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Mixed array",
|
||||
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
int64(4),
|
||||
"foobar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil array",
|
||||
input: "*-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Nested arrays",
|
||||
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
|
||||
expected: []interface{}{
|
||||
[]interface{}{"foo", "bar"},
|
||||
[]interface{}{"baz"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_InvalidInput tests error handling for invalid input
|
||||
func TestRESPReader_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Unknown type byte",
|
||||
input: "?invalid\r\n",
|
||||
},
|
||||
{
|
||||
name: "Incomplete response",
|
||||
input: "+OK",
|
||||
},
|
||||
{
|
||||
name: "Missing CRLF in bulk string",
|
||||
input: "$5\r\nhello",
|
||||
},
|
||||
{
|
||||
name: "Truncated array",
|
||||
input: "*3\r\n:1\r\n:2\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_EOF tests handling of EOF
|
||||
func TestRESPReader_EOF(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(""))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, io.EOF))
|
||||
}
|
||||
|
||||
// TestRESPHelpers tests helper functions
|
||||
func TestRESPHelpers(t *testing.T) {
|
||||
t.Run("RESPString", func(t *testing.T) {
|
||||
// Valid string
|
||||
result, err := RESPString("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", result)
|
||||
|
||||
// Byte slice
|
||||
result, err = RESPString([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "world", result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPString(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPString(123)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("RESPInt", func(t *testing.T) {
|
||||
// Valid int64
|
||||
result, err := RESPInt(int64(42))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Valid int
|
||||
result, err = RESPInt(42)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPInt(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPInt("string")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected interface{}
|
||||
name string
|
||||
response string
|
||||
command []string
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
command: []string{"PING"},
|
||||
response: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
},
|
||||
{
|
||||
name: "GET command with result",
|
||||
command: []string{"GET", "mykey"},
|
||||
response: "$7\r\nmyvalue\r\n",
|
||||
expected: "myvalue",
|
||||
},
|
||||
{
|
||||
name: "GET command with nil",
|
||||
command: []string{"GET", "nonexistent"},
|
||||
response: "$-1\r\n",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "DEL command",
|
||||
command: []string{"DEL", "key1", "key2"},
|
||||
response: ":2\r\n",
|
||||
expected: int64(2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Write command
|
||||
writeBuf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(writeBuf)
|
||||
err := writer.WriteCommand(tt.command...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
reader := NewRESPReader(strings.NewReader(tt.response))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.expected == nil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleflightCache wraps a CacheBackend with singleflight deduplication
|
||||
// to prevent thundering herd problems when multiple concurrent requests
|
||||
// try to fetch the same uncached key.
|
||||
type SingleflightCache struct {
|
||||
backend CacheBackend
|
||||
mu sync.Mutex
|
||||
calls map[string]*singleflightCall
|
||||
|
||||
// Metrics
|
||||
deduplicatedCalls atomic.Int64
|
||||
totalCalls atomic.Int64
|
||||
}
|
||||
|
||||
// singleflightCall represents an in-flight or completed fetch call
|
||||
type singleflightCall struct {
|
||||
wg sync.WaitGroup
|
||||
val []byte
|
||||
ttl time.Duration
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
// NewSingleflightCache creates a new singleflight-wrapped cache backend
|
||||
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
|
||||
return &SingleflightCache{
|
||||
backend: backend,
|
||||
calls: make(map[string]*singleflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Fetcher is a function type that fetches data when cache misses
|
||||
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
|
||||
|
||||
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
|
||||
// per key when there's a cache miss. Concurrent calls for the same key will
|
||||
// wait for the first call to complete and share its result.
|
||||
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
|
||||
s.totalCalls.Add(1)
|
||||
|
||||
// Try cache first
|
||||
value, _, exists, err := s.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Cache miss - use singleflight
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if there's already an in-flight call for this key
|
||||
if call, ok := s.calls[key]; ok {
|
||||
s.mu.Unlock()
|
||||
s.deduplicatedCalls.Add(1)
|
||||
|
||||
// Wait for the in-flight call to complete
|
||||
call.wg.Wait()
|
||||
|
||||
// Check context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Create new call
|
||||
call := &singleflightCall{}
|
||||
call.wg.Add(1)
|
||||
s.calls[key] = call
|
||||
s.mu.Unlock()
|
||||
|
||||
// Execute the fetcher
|
||||
call.val, call.ttl, call.err = fetcher(ctx)
|
||||
call.done = true
|
||||
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is cancelled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Signal waiting goroutines
|
||||
call.wg.Done()
|
||||
|
||||
// Clean up the call from the map after a short delay
|
||||
// This allows late arrivals to still benefit from the result
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
if c, ok := s.calls[key]; ok && c == call {
|
||||
delete(s.calls, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the underlying cache backend
|
||||
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
return s.backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores a value in the underlying cache backend
|
||||
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.backend.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the underlying cache backend
|
||||
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the underlying cache backend
|
||||
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the underlying cache backend
|
||||
func (s *SingleflightCache) Clear(ctx context.Context) error {
|
||||
return s.backend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics including singleflight metrics
|
||||
func (s *SingleflightCache) GetStats() map[string]interface{} {
|
||||
stats := s.backend.GetStats()
|
||||
|
||||
// Add singleflight-specific stats
|
||||
totalCalls := s.totalCalls.Load()
|
||||
deduped := s.deduplicatedCalls.Load()
|
||||
|
||||
stats["singleflight_total_calls"] = totalCalls
|
||||
stats["singleflight_deduplicated"] = deduped
|
||||
if totalCalls > 0 {
|
||||
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
|
||||
} else {
|
||||
stats["singleflight_dedup_rate"] = float64(0)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
stats["singleflight_inflight"] = len(s.calls)
|
||||
s.mu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend
|
||||
func (s *SingleflightCache) Close() error {
|
||||
return s.backend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (s *SingleflightCache) Ping(ctx context.Context) error {
|
||||
return s.backend.Ping(ctx)
|
||||
}
|
||||
|
||||
// GetBackend returns the underlying cache backend
|
||||
func (s *SingleflightCache) GetBackend() CacheBackend {
|
||||
return s.backend
|
||||
}
|
||||
|
||||
// ResetStats resets the singleflight statistics
|
||||
func (s *SingleflightCache) ResetStats() {
|
||||
s.totalCalls.Store(0)
|
||||
s.deduplicatedCalls.Store(0)
|
||||
}
|
||||
|
||||
// Ensure SingleflightCache implements CacheBackend
|
||||
var _ CacheBackend = (*SingleflightCache)(nil)
|
||||
+510
@@ -0,0 +1,510 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
|
||||
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CacheHit", func(t *testing.T) {
|
||||
key := "existing-key"
|
||||
value := []byte("existing-value")
|
||||
|
||||
// Pre-populate cache
|
||||
err := cache.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return []byte("fetched-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, result)
|
||||
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
|
||||
})
|
||||
|
||||
t.Run("CacheMiss", func(t *testing.T) {
|
||||
key := "missing-key"
|
||||
expectedValue := []byte("fetched-value")
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, result)
|
||||
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
|
||||
|
||||
// Verify value was stored in cache
|
||||
cached, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, cached)
|
||||
})
|
||||
|
||||
t.Run("FetcherError", func(t *testing.T) {
|
||||
key := "error-key"
|
||||
expectedErr := errors.New("fetch failed")
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Verify nothing was stored in cache
|
||||
_, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
|
||||
func TestSingleflightCache_Deduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "dedup-key"
|
||||
expectedValue := []byte("dedup-value")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([][]byte, concurrency)
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same result
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.NoError(t, errs[i])
|
||||
assert.Equal(t, expectedValue, results[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
|
||||
|
||||
// Verify deduplication stats
|
||||
stats := cache.GetStats()
|
||||
deduped := stats["singleflight_deduplicated"].(int64)
|
||||
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
|
||||
func TestSingleflightCache_DifferentKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetchStarted := make(chan struct{}, 3)
|
||||
fetchComplete := make(chan struct{})
|
||||
|
||||
fetcher := func(key string) Fetcher {
|
||||
return func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
fetchStarted <- struct{}{}
|
||||
<-fetchComplete // Wait for signal
|
||||
return []byte("value-" + key), time.Minute, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Launch concurrent requests for different keys
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("key-%d", idx)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all fetches to start
|
||||
for i := 0; i < 3; i++ {
|
||||
<-fetchStarted
|
||||
}
|
||||
|
||||
// All 3 fetches should be running in parallel
|
||||
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
|
||||
|
||||
// Release all fetches
|
||||
close(fetchComplete)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ContextCancellation tests context cancellation
|
||||
func TestSingleflightCache_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
key := "cancel-key"
|
||||
fetchStarted := make(chan struct{})
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
close(fetchStarted)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Start first request with long timeout
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}()
|
||||
|
||||
// Wait for fetch to start
|
||||
<-fetchStarted
|
||||
|
||||
// Start second request with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
|
||||
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "error-prop-key"
|
||||
expectedErr := errors.New("intentional error")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same error
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.Error(t, errs[i])
|
||||
assert.Equal(t, expectedErr, errs[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load())
|
||||
}
|
||||
|
||||
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
|
||||
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Set", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, _, exists, err := cache.Get(ctx, "set-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("set-value"), val)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, ttl, exists, err := cache.Get(ctx, "get-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("get-value"), val)
|
||||
assert.Greater(t, ttl, time.Duration(0))
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := cache.Delete(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := cache.Exists(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
exists, err := cache.Exists(ctx, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = cache.Exists(ctx, "exists-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := cache.Exists(ctx, "clear-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
err := cache.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Stats tests statistics tracking
|
||||
func TestSingleflightCache_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Make some calls
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := cache.GetStats()
|
||||
|
||||
// Check singleflight stats exist
|
||||
assert.Contains(t, stats, "singleflight_total_calls")
|
||||
assert.Contains(t, stats, "singleflight_deduplicated")
|
||||
assert.Contains(t, stats, "singleflight_dedup_rate")
|
||||
assert.Contains(t, stats, "singleflight_inflight")
|
||||
|
||||
// Verify values
|
||||
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
|
||||
|
||||
// Also check underlying backend stats are included
|
||||
assert.Contains(t, stats, "hits")
|
||||
assert.Contains(t, stats, "misses")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ResetStats tests stats reset
|
||||
func TestSingleflightCache_ResetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
|
||||
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
|
||||
|
||||
stats := cache.GetStats()
|
||||
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
|
||||
|
||||
// Reset stats
|
||||
cache.ResetStats()
|
||||
|
||||
stats = cache.GetStats()
|
||||
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
|
||||
}
|
||||
|
||||
// TestSingleflightCache_GetBackend tests GetBackend method
|
||||
func TestSingleflightCache_GetBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
assert.Equal(t, backend, cache.GetBackend())
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
|
||||
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key-%d", i%100)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
|
||||
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(time.Millisecond) // Simulate slow fetch
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
|
||||
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// All goroutines hit the same key
|
||||
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
|
||||
}
|
||||
})
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user