Compare commits

..

21 Commits

Author SHA1 Message Date
lukaszraczylo 413e4a1b7d LRU + cache conflicts prevention. (#104)
* LRU + cache conflicts prevention.

* Bugfix universalCache flooding ( issue #105 )

  1. Traefik cancels the context for old plugin instances
  2. Each plugin's Close() method is called
  3. The CacheInterfaceWrapper.Close() was calling cache.Close() on the shared singleton caches
  4. Each Close() triggered Clear() which logged "Cleared all items" at INFO level
2025-12-24 18:54:39 +00:00
lukaszraczylo 69e0d98c67 fixup! Add signing of the plugin on release. 2025-12-24 12:33:33 +00:00
lukaszraczylo 6d893df12b Add signing of the plugin on release. 2025-12-15 00:38:35 +00:00
lukaszraczylo 6efb78b7a8 Smarter approach to the cookies (#103)
* Smarter approach to the cookies

  - Single maxCookieSize = 1400 constant with clear documentation
  - Combined cookie storage for ~40-45% size reduction
  - Backward compatible migration from legacy cookies

* Tuneup the code.
2025-12-12 18:35:06 +00:00
lukaszraczylo d0b920c4f0 multiple realms fix (#102)
* Allow to use multiple realms

This change is a ressurection of PR #88 which can't be merged due to significant refactor of the codebase.

* Fix the autocleanup routine to handle multiple realms correctly, update tests.

* Metadata rediscovery when provider is unavailable for any reason during the start.

This one prevents the permanent 503 from the plugin when OIDC provider was for some reason unavailable during the start.
2025-12-10 13:07:22 +00:00
lukaszraczylo c474bbafd6 Cleanup [dec2025] (#101)
* Cleanup excessive comments.

* Remove leftovers hanging around from previous refactor

* Improve test coverage
2025-12-09 01:38:02 +00:00
lukaszraczylo 9126c74723 December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)
* Allow internal IPs for OIDC configuration via extra flag.

Addresses issue #97

* Allow for internal IPs in OIDC configuration.

Addresses issue #97.

* feat: Add allowPrivateIPAddresses config option for internal networks

Adds a new configuration option `allowPrivateIPAddresses` that allows
OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x,
192.168.x.x). This is useful for internal deployments where Keycloak or
other OIDC providers run on private networks without DNS resolution.

Security considerations:
- Loopback addresses (127.0.0.1, localhost, ::1) remain blocked
- Link-local addresses (169.254.x.x) remain blocked
- Default is false (secure by default)

Fixes #97

* feat: Support non-email user identifiers for Azure AD

Add userIdentifierClaim configuration option to support Azure AD users
without email addresses. This allows using alternative JWT claims like
"sub", "oid", "upn", or "preferred_username" for user identification.

- Default behavior uses "email" claim (backward compatible)
- Falls back to "sub" claim if configured claim is missing
- allowedUsers matches against the configured claim value
- allowedUserDomains only applies when using email-based identification

Fixes #95

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Headers too big and 431 responses

Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services.

  - When minimalHeaders: false (default): All headers are forwarded as before
    - X-Forwarded-User (always set)
    - X-Auth-Request-Redirect
    - X-Auth-Request-User
    - X-Auth-Request-Token (the large ID token)
    - X-User-Groups, X-User-Roles (if configured)
  - When minimalHeaders: true: Reduces header overhead
    - X-Forwarded-User (always set)
    - X-User-Groups, X-User-Roles (still forwarded if configured)
    - Custom templated headers (still processed)
    - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect

Fixes issues #64 and #86
2025-12-08 14:21:17 +00:00
lukaszraczylo a750c4f5b9 Size computation for allocation may overflow (#99)
* Size computation for allocation may overflow

Performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
2025-12-08 11:22:28 +00:00
lukaszraczylo 56051779ee Hotfix: goreleaser archive format. 2025-12-08 02:39:40 +00:00
lukaszraczylo 3f126d50f3 Force the v in the release tags and name. 2025-12-08 02:34:10 +00:00
lukaszraczylo 91f0fc9ab8 Switch to go releaser 2025-12-08 02:32:46 +00:00
lukaszraczylo 66b9ed0861 Reauthentication + redis fix
When introspection explicitly returns that a token is inactive/revoked/expired, the plugin now properly triggers re-authentication or refresh instead of falling back to ID token validation. This fixes the functional issue where users
weren't being redirected to re-authenticate.
Redis change ensures that when the caller's context is cancelled (e.g., the 200ms timeout in UniversalCache.Get()), the operation aborts quickly instead of continuing with retries.
2025-12-01 13:47:28 +00:00
lukaszraczylo e64fc7f730 Add redis support for distributed caching (#83)
* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-11-30 02:18:46 +00:00
lukaszraczylo 5fcbd54955 Add sharded cache and prevention of CPU spikes / locks (#96)
* Add sharded cache and prevention of CPU spikes / locks

* Add dynamic client registration with oidc provider

* Fix race condition introduced during the sharded cache implementation.

* Add page for traefikoidc.
2025-11-30 01:41:12 +00:00
lukaszraczylo e70cd1907c Create CNAME 2025-11-30 01:28:07 +00:00
lukaszraczylo e45b06c86d Fix markdown issues. 2025-10-17 14:40:50 +01:00
lukaszraczylo ae59a5e88a 0.7.10 (#80)
* Add ability to disable replay protection. - This is useful for runs with multiple traefik replicas to avoid false positives and tokens re-creation.
* Enhance the CI/CD pipelines
* Increase test coverage.
* Update vendored dependencies.
* Update behaviour on forceHTTPS as per issue #82
2025-10-16 10:56:28 +01:00
lukaszraczylo 79e9b164f9 release 0.7.9 (#78)
* Speed improvements.

After introduction of introspection the plugin became significantly slower.
This commit introduces several optimizations to bring the speed back up.

* Add relevant documentation and tests.
2025-10-13 10:43:35 +01:00
lukaszraczylo 93888e56d1 fixup! Multiple issues addressed (#76) 2025-10-09 00:56:53 +01:00
lukaszraczylo eff9bd7bd2 Multiple issues addressed (#76)
- Issue #74
- Issue #14
2025-10-09 00:44:03 +01:00
lukaszraczylo bde1db1c3b traefik plugin 0.7.7 (#73)
* Automatic discovery of the scopes.

Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication:  FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication:  SUCCEEDS

* Resolves issue #74 by enabling user to specify expected audience in the configuration.

* Fix flaky tests.
2025-10-08 11:44:00 +01:00
521 changed files with 143044 additions and 36885 deletions
+38
View File
@@ -0,0 +1,38 @@
# Code Owners for traefik-oidc
# These owners will be automatically requested for review when someone opens a PR
# Default owner for everything in the repo
* @lukaszraczylo
# Core authentication and middleware
/middleware/ @lukaszraczylo
/auth/ @lukaszraczylo
/handlers/ @lukaszraczylo
# OIDC providers
/internal/providers/ @lukaszraczylo
# Session management and security
/session/ @lukaszraczylo
/internal/security/ @lukaszraczylo
/security/ @lukaszraczylo
# Token management
/internal/token/ @lukaszraczylo
# Configuration
/config/ @lukaszraczylo
/.traefik.yml @lukaszraczylo
# GitHub Actions and CI/CD
/.github/ @lukaszraczylo
/.github/workflows/ @lukaszraczylo
/.golangci.yml @lukaszraczylo
# Documentation
/docs/ @lukaszraczylo
README.md @lukaszraczylo
# Dependencies
go.mod @lukaszraczylo
go.sum @lukaszraczylo
+123
View File
@@ -0,0 +1,123 @@
## Description
<!-- Provide a brief description of the changes in this PR -->
## Type of Change
<!-- Mark the relevant option with an "x" -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update
- [ ] Performance improvement
- [ ] Code refactoring
- [ ] Security fix
- [ ] Provider-specific fix/enhancement
## Related Issues
<!-- Link to related issues using #issue_number -->
Fixes #
Related to #
## Changes Made
<!-- List the main changes made in this PR -->
-
-
-
## Provider Impact
<!-- If this affects specific OIDC providers, list them here -->
- [ ] Google
- [ ] Azure AD
- [ ] Auth0
- [ ] Okta
- [ ] Keycloak
- [ ] AWS Cognito
- [ ] GitLab
- [ ] GitHub
- [ ] Generic OIDC
- [ ] All providers
## Testing Performed
<!-- Describe the tests you ran to verify your changes -->
- [ ] Unit tests pass locally
- [ ] Integration tests pass locally
- [ ] Race detector shows no issues
- [ ] Memory leak tests pass
- [ ] Manual testing performed
### Test Configuration
<!-- Provide details about your test configuration if applicable -->
**Provider tested:**
**Go version:**
**Traefik version:**
## Security Considerations
<!-- Describe any security implications of these changes -->
- [ ] This PR does not introduce security vulnerabilities
- [ ] Security scanning has been performed
- [ ] Credentials/secrets are properly handled
- [ ] Input validation is implemented
## Performance Impact
<!-- Describe any performance implications -->
- [ ] No performance impact expected
- [ ] Performance improved (describe how)
- [ ] Performance may be affected (describe why and mitigation)
## Breaking Changes
<!-- If this is a breaking change, describe the impact and migration path -->
**Breaking changes:**
**Migration guide:**
## Checklist
<!-- Ensure all items are checked before requesting review -->
- [ ] My code follows the project's code style
- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published
## Additional Context
<!-- Add any other context, screenshots, or information about the PR here -->
## Screenshots (if applicable)
<!-- Add screenshots to help explain your changes -->
---
**For Reviewers:**
Please verify:
- [ ] Code quality and style
- [ ] Test coverage is adequate
- [ ] Security implications reviewed
- [ ] Documentation is updated
- [ ] No performance regressions
+52
View File
@@ -0,0 +1,52 @@
version: 2
updates:
# Maintain dependencies for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 5
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "github-actions"
reviewers:
- "lukaszraczylo"
# Maintain Go module dependencies
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 10
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "go"
reviewers:
- "lukaszraczylo"
# Group patch updates together
groups:
patch-updates:
patterns:
- "*"
update-types:
- "patch"
minor-updates:
patterns:
- "*"
update-types:
- "minor"
# Ignore certain dependencies if needed
ignore:
# Example: ignore specific versions
# - dependency-name: "github.com/example/package"
# versions: ["1.x", "2.x"]
+9
View File
@@ -0,0 +1,9 @@
# Ensure consistent line endings
* text=auto eol=lf
# GitHub Actions files should use LF
*.yml text eol=lf
*.yaml text eol=lf
# Shell scripts should use LF
*.sh text eol=lf
+225
View File
@@ -0,0 +1,225 @@
# GitHub Actions Workflows
This directory contains CI/CD workflows for the Traefik OIDC middleware.
## Workflows
### PR Validation (`pr-validation.yml`)
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
**Triggered on:**
- Pull requests to `main` branch
- Pushes to `main` branch
**Parallel Jobs (20+ concurrent checks):**
#### Code Quality
- **Quick Checks** - Format, go vet, go mod verify
- **golangci-lint** - Comprehensive linting
- **Staticcheck** - Static analysis
#### Security
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database check
- **CodeQL** - GitHub's code analysis
#### Testing
- **Race Detector** - Concurrent access bug detection
- **Coverage** - Test coverage with 75% threshold
- **Memory Leaks** - Goroutine and memory leak detection
- **Integration Tests** - Full integration test suite
- **Regression Tests** - Prevent previously fixed bugs
- **Security Edge Cases** - Security-specific scenarios
- **Session Tests** - Session management validation
- **Token Tests** - Token validation scenarios
- **CSRF Tests** - CSRF protection validation
#### Provider Testing (Matrix)
Tests run in parallel for each OIDC provider:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Compatibility
- **Benchmarks** - Performance regression detection
- **Build Matrix** - linux/darwin × amd64/arm64
- **Go Versions** - Go 1.23 and 1.24 compatibility
#### Final Validation
- **All Checks Passed** - Ensures all jobs succeeded
## Workflow Features
### 🚀 Parallel Execution
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
### 📊 Coverage Reporting
- Automatic PR comments with coverage statistics
- Per-package coverage breakdown
- 75% coverage threshold enforcement
### 🔒 Security First
- Multiple security scanners (gosec, govulncheck, CodeQL)
- SARIF report uploads for GitHub Security tab
- Security edge case testing
### 🎯 Comprehensive Testing
- Race condition detection
- Memory leak detection
- Provider-specific testing
- Integration and regression tests
### 📈 Performance Tracking
- Benchmark results stored as artifacts
- Performance regression detection
### ✅ Quality Gates
All checks must pass before PR can be merged:
- Code formatting and style
- Security vulnerabilities
- Test coverage threshold
- Race conditions
- Memory leaks
- Build success on all platforms
## Local Development
### Run checks locally before pushing:
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Run tests with race detector
go test -race -timeout=15m -count=1 ./...
# Check coverage
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# Run specific test suites
go test -v -run='.*Leak.*' ./... # Memory leak tests
go test -v -run='.*Integration.*' ./... # Integration tests
go test -v -run='.*Regression.*' ./... # Regression tests
# Run benchmarks
go test -bench=. -benchmem ./...
# Security scan
gosec ./...
govulncheck ./...
```
### Required Tools
Install these tools for local development:
```bash
# golangci-lint
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck
go install golang.org/x/vuln/cmd/govulncheck@latest
```
## Troubleshooting
### Workflow Fails
1. **Check job status** - Click on failed job for details
2. **Review logs** - Expand failed steps to see error messages
3. **Run locally** - Reproduce issue with local commands above
4. **Check coverage** - Ensure test coverage meets 75% threshold
### Coverage Below Threshold
Add tests to increase coverage:
```bash
# See which lines aren't covered
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
### Race Condition Detected
Run with race detector locally:
```bash
go test -race -v ./...
```
### Provider Test Failure
Test specific provider:
```bash
go test -v -run='.*Azure.*' ./internal/providers/...
```
## Performance Optimization
The workflow is optimized for speed:
- **Parallel execution** - All independent jobs run simultaneously
- **Go caching** - Dependencies cached between runs
- **Strategic ordering** - Quick checks run first for fast feedback
- **Fail-fast disabled** - Continue running all tests even if some fail
## Workflow Monitoring
### GitHub Actions Dashboard
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
### Status Badges
Add to README.md:
```markdown
![PR Validation](https://github.com/{owner}/{repo}/actions/workflows/pr-validation.yml/badge.svg)
```
### Notifications
Configure in repository settings:
- Settings → Notifications
- Choose email or Slack notifications for workflow failures
## Maintenance
### Update Go Version
Edit in workflow file:
```yaml
go-version: '1.24' # Update this
```
### Adjust Coverage Threshold
Edit in workflow file:
```yaml
THRESHOLD=75 # Adjust this value
```
### Add New Provider
Add to provider matrix:
```yaml
matrix:
provider:
- new_provider # Add here
```
## Additional Resources
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
- [golangci-lint Configuration](../.golangci.yml)
- [Dependabot Configuration](../dependabot.yml)
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
+23
View File
@@ -0,0 +1,23 @@
name: Pull Request
on:
pull_request:
branches:
- main
push:
branches:
- "**"
- "!main"
permissions:
contents: read
pull-requests: write
security-events: write
jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+23
View File
@@ -0,0 +1,23 @@
name: Release
on:
push:
branches:
- main
paths:
- "**.go"
- "go.mod"
- "go.sum"
workflow_dispatch:
permissions:
id-token: write
contents: write
packages: write
jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24.11"
secrets: inherit
+2 -1
View File
@@ -1,2 +1,3 @@
docker/
.claude/
.claude/*.out
*.test
+192
View File
@@ -0,0 +1,192 @@
version: "2"
run:
go: "1.24"
modules-download-mode: readonly
tests: true
linters:
enable:
- bodyclose
- dupl
- goconst
- gocritic
- gocyclo
- goprintffuncname
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- lll
- mnd
- testpackage
- wsl
settings:
dupl:
threshold: 200 # Allow intentional duplication in provider patterns and token management
errcheck:
check-type-assertions: true
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
exclude-functions:
- (io.Closer).Close
- (*database/sql.Rows).Close
- (*database/sql.Stmt).Close
- (io.Writer).Write
- (*net/http.ResponseWriter).Write
- fmt.Fprintf
- fmt.Fprint
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
excludes:
- G104
- G404
severity: medium
confidence: medium
govet:
disable:
- fieldalignment
- shadow
enable-all: true
misspell:
locale: US
ignore-rules:
- traefik
- oidc
- keycloak
nolintlint:
require-explanation: true
require-specific: true
allow-unused: false
prealloc:
simple: true
range-loops: true
for-loops: false
revive:
rules:
- name: blank-imports
- name: context-as-argument
- name: context-keys-type
- name: dot-imports
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
- name: errorf
- name: empty-block
- name: superfluous-else
- name: unused-parameter
- name: unreachable-code
- name: redefines-builtin-id
unparam:
check-exported: false
staticcheck:
checks:
- all
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1012 # Use fmt.Fprintf - style preference
- -ST1003 # Package name format - allowed for test packages
exclusions:
generated: lax
rules:
- linters:
- bodyclose
- dupl
- errcheck
- goconst
- gocyclo
- gosec
- noctx
- prealloc
- unparam
path: _test\.go
- linters:
- dupl
- gocyclo
path: test.*\.go
- linters:
- gocritic
- unused
path: mocks.*\.go
- linters:
- gosec
text: 'G404:'
- linters:
- all
path: vendor/
- linters:
- goconst
path: (.+)_test\.go
- linters:
- dupl
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
- linters:
- dupl
path: session\.go
- linters:
- dupl
path: session_chunk_manager\.go
text: "(extractJWTExpiration|extractJWTIssuedAt)"
paths:
- third_party$
- builtin$
- examples$
issues:
max-issues-per-linter: 0
max-same-issues: 0
uniq-by-line: true
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$
+60
View File
@@ -0,0 +1,60 @@
version: 2
# Traefik plugins are source-only - no binary builds
# Traefik loads plugins via Yaegi interpreter at runtime
builds:
- skip: true
# Create source archive for GitHub releases
archives:
- formats: [tar.gz]
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
files:
- "*.go"
- "**/*.go"
- go.mod
- go.sum
- .traefik.yml
- LICENSE*
- README*
# Exclude test files and vendor from release archive
- "!**/*_test.go"
- "!vendor/**"
- "!docker/**"
- "!integration/**"
- "!regression/**"
- "!examples/**"
- "!docs/**"
checksum:
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
algorithm: sha256
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
- "^Merge"
- "^WIP"
- "^chore:"
release:
github:
owner: lukaszraczylo
name: traefikoidc
name_template: "v{{ .Version }}"
draft: false
prerelease: auto
signs:
- cmd: cosign
signature: "${artifact}.sigstore.json"
args:
- sign-blob
- "--bundle=${signature}"
- "${artifact}"
- "--yes"
artifacts: checksum
output: true
+774 -43
View File
@@ -31,6 +31,7 @@ summary: |
- Flexible configuration with multiple deployment scenarios
- Memory-efficient operation with automatic cleanup
- Extensive logging and debugging capabilities
- Redis cache support for multi-replica deployments with automatic failover
It supports various authentication scenarios including:
- Basic authentication with customizable callback and logout URLs
@@ -73,7 +74,16 @@ testData:
- admin
- developer
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
# Custom claim names for Auth0 and other providers with namespaced claims
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
# When NOT specified in config: defaults to FALSE (Go zero value)
# When running behind load balancer that terminates TLS: MUST set to TRUE
# See: https://github.com/lukaszraczylo/traefikoidc/issues/82
forceHTTPS: true # Forces HTTPS scheme for redirect URIs (default when not specified: false)
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
@@ -84,25 +94,36 @@ testData:
- /metrics
headers: # Custom headers to set with templated values from claims and tokens
# NOTE: If you encounter "can't evaluate field AccessToken in type bool" errors,
# you may need to escape the templates. See the headers section in configuration below.
# NOTE: Use double curly braces to escape template expressions in YAML
# See the headers section in configuration below for details
- name: "X-User-Email"
value: "{{.Claims.email}}"
value: "{{{{.Claims.email}}}}"
- name: "X-User-ID"
value: "{{.Claims.sub}}"
value: "{{{{.Claims.sub}}}}"
- name: "Authorization"
value: "Bearer {{.AccessToken}}"
value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles"
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
# Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
cookieDomain: "" # Explicit domain for session cookies (e.g., ".example.com" for multi-subdomain setups)
cookiePrefix: "" # Custom prefix for cookie names (e.g., "_oidc_myapp_" for session isolation between middleware instances)
sessionMaxAge: 86400 # Maximum session age in seconds (default: 86400 = 24 hours, 0 = use default)
overrideScopes: false # When true, replaces default scopes instead of appending (default: false)
refreshGracePeriodSeconds: 60 # Seconds before token expiry to attempt proactive refresh (default: 60)
# Auth0 / Custom API Audience Configuration
audience: "" # Custom audience for access token validation (default: clientID)
strictAudienceValidation: false # Reject sessions with audience mismatch (prevents token confusion attacks)
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
# Security Headers Configuration (enabled by default with 'default' profile)
securityHeaders:
enabled: true
@@ -115,11 +136,53 @@ testData:
- "https://*.example.com"
corsAllowCredentials: true
# Cross-origin policies
permissionsPolicy: "geolocation=(), camera=(), microphone=()"
crossOriginEmbedderPolicy: "require-corp"
crossOriginOpenerPolicy: "same-origin"
crossOriginResourcePolicy: "same-origin"
# Custom headers
customHeaders:
X-Custom-Header: "production"
X-API-Version: "v1"
# Example with Redis cache for multi-replica deployments
testDataWithRedis:
# Required OIDC parameters (same as standard configuration)
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
callbackURL: /oauth2/callback
sessionEncryptionKey: your-64-character-encryption-key-at-least-32-bytes
# Standard optional parameters
logLevel: info
allowedUserDomains:
- company.com
# Redis cache configuration for multi-replica support
redis:
enabled: true # Enable Redis caching
address: "redis:6379" # Redis server address
password: "redis-password" # Redis authentication password
db: 0 # Redis database number (0-15)
keyPrefix: "traefikoidc:" # Prefix for all Redis keys
cacheMode: "hybrid" # Cache mode: redis, hybrid, or memory
poolSize: 20 # Maximum number of connections
connectTimeout: 5 # Connection timeout in seconds
readTimeout: 3 # Read operation timeout
writeTimeout: 3 # Write operation timeout
enableTLS: false # Use TLS for Redis connection
tlsSkipVerify: false # Skip TLS certificate verification
hybridL1Size: 500 # L1 cache size for hybrid mode
hybridL1MemoryMB: 10 # L1 memory limit for hybrid mode
enableCircuitBreaker: true # Enable circuit breaker
circuitBreakerThreshold: 5 # Failures before opening circuit
circuitBreakerTimeout: 60 # Timeout before retry (seconds)
enableHealthCheck: true # Enable periodic health checks
healthCheckInterval: 30 # Health check interval (seconds)
# --- Common Configuration Examples ---
#
# 🔒 HIGH-SECURITY CONFIGURATION
@@ -169,11 +232,11 @@ testData:
# corsAllowedOrigins: ["https://app.example.com"]
# corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
# corsAllowedHeaders: ["Authorization", "Content-Type", "X-API-Key"]
# headers: # Custom headers with OIDC claims
# headers: # Custom headers with OIDC claims (use double curly braces)
# - name: "X-User-Email"
# value: "{{.Claims.email}}"
# value: "{{{{.Claims.email}}}}"
# - name: "X-User-ID"
# value: "{{.Claims.sub}}"
# value: "{{{{.Claims.sub}}}}"
# --- Provider Specific Configuration Examples ---
#
@@ -206,6 +269,8 @@ testData:
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
# - admin
# - editor
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
# # See README.md "Provider Configuration Recommendations" for Keycloak.
@@ -227,6 +292,26 @@ testData:
# - "AppRoleName"
# # See README.md "Provider Configuration Recommendations" for Azure AD.
# --- Azure AD Users Without Email Example (Issue #95) ---
# testDataAzureADNoEmail:
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
# clientID: your-azure-ad-client-id
# clientSecret: your-azure-ad-client-secret
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
# # Use 'sub' claim instead of 'email' for user identification
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
# overrideScopes: true # Remove email scope if not needed
# scopes:
# - openid
# - profile
# - groups # For group-based access control
# # When using non-email identifiers, allowedUsers matches against the claim value
# allowedUsers:
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
# --- Google Workspace / Google Cloud Identity Example ---
# testDataGoogle:
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
@@ -312,6 +397,20 @@ testData:
# clientSecret: your-auth0-client-secret # Store securely
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-auth0"
#
# # Auth0 Audience Configuration (for custom APIs)
# # Scenario 1 (Recommended): Custom API with JWT access tokens
# audience: "https://my-api.example.com" # Your API identifier from Auth0 dashboard
# strictAudienceValidation: true # Enforce proper audience validation for security
#
# # Scenario 2 (Backward Compatible): Default audience (uses client_id)
# # audience: "" # Leave empty or omit to use client_id as audience
# # strictAudienceValidation: false # Allows fallback to ID token validation (logs warnings)
#
# # Scenario 3: Opaque (non-JWT) access tokens
# # allowOpaqueTokens: true # Enable opaque token support
# # requireTokenIntrospection: true # Require RFC 7662 token introspection
#
# scopes: # Defaults ["openid", "profile", "email"]. Add custom scopes if your Auth0 Rules/Actions require them.
# - read:custom_data # Example custom scope
# allowedRolesAndGroups: # Based on claims added via Auth0 Rules or Actions (e.g. namespaced claims)
@@ -319,7 +418,7 @@ testData:
# - editor
# # Use Auth0 Rules or Actions to add custom claims (roles, permissions) to the ID Token.
# # Ensure postLogoutRedirectURI is in Auth0 app's "Allowed Logout URLs".
# # See README.md "Provider Configuration Recommendations" for Auth0.
# # For detailed Auth0 audience configuration, see AUTH0_AUDIENCE_GUIDE.md
# --- Generic OIDC Provider Example ---
# testDataGenericOIDC:
@@ -448,9 +547,24 @@ configuration:
forceHTTPS:
type: boolean
description: |
Forces the use of HTTPS for all URLs.
This is recommended for security in production environments.
Default: true
Forces HTTPS scheme for redirect URIs regardless of request headers or TLS state.
⚠️ CRITICAL CONFIGURATION for TLS Termination Scenarios:
When running Traefik behind a load balancer that terminates TLS (AWS ALB,
Google Cloud Load Balancer, Azure Application Gateway, etc.), you MUST set
this to true. Without it, redirect URIs will use http:// instead of https://,
causing OAuth callback failures.
How it works:
- When true: Always uses https:// for redirect URIs (highest priority)
- When false: Detects scheme from X-Forwarded-Proto header or TLS state
- When NOT specified: Defaults to false (Go zero value for bool)
Default: false (when not specified in configuration)
Recommended: true (for production environments and TLS termination scenarios)
See: https://github.com/lukaszraczylo/traefikoidc/issues/82
required: false
rateLimit:
@@ -516,6 +630,38 @@ configuration:
items:
type: string
userIdentifierClaim:
type: string
description: |
Specifies the JWT claim to use as the user identifier for authentication and authorization.
This allows authentication for users without email addresses, such as Azure AD service
accounts or organizational accounts that don't have email attributes configured.
When set to a non-email claim (e.g., "sub", "oid", "upn"):
- AllowedUsers will match against this claim value instead of email
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
- The session stores this identifier as the user's identity
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
Common values by provider:
- Default: "email" (standard email-based identification)
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
- Generic OIDC: "sub" (always present per OIDC specification)
- Keycloak: "sub", "preferred_username"
Example for Azure AD users without email:
```yaml
userIdentifierClaim: sub
allowedUsers:
- "abc123-user-object-id"
- "xyz789-another-user-id"
```
Default: "email"
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
required: false
revocationURL:
type: string
description: |
@@ -549,28 +695,101 @@ configuration:
cookieDomain:
type: string
description: |
Explicit domain for session cookies. This is important for multi-subdomain setups
Explicit domain for session cookies. This is important for multi-subdomain setups
and reverse proxy deployments to ensure consistent cookie handling.
When set, all session cookies will use this domain. When not set, the domain
is auto-detected from the request headers (X-Forwarded-Host or Host).
Use a leading dot for subdomain-wide cookies (e.g., ".example.com" allows
cookies to be shared between app.example.com, api.example.com, etc.).
Use a specific domain for host-only cookies (e.g., "app.example.com" restricts
cookies to that exact domain).
This setting is crucial to prevent authentication issues like "CSRF token missing
in session" errors that can occur when cookies are created with inconsistent domains.
Examples:
- ".example.com" - Allows all subdomains to share cookies
- "app.example.com" - Restricts cookies to this specific host
Default: "" (auto-detected from request headers)
required: false
cookiePrefix:
type: string
description: |
Custom prefix for session cookie names. This is essential for running multiple
middleware instances with different authorization requirements on the same domain.
By default, all middleware instances use the same cookie names (_oidc_raczylo_m,
_oidc_raczylo_a, etc.), which means they share session state. When you have
multiple instances with different access restrictions (e.g., one for general users
and one for admins), this session sharing can lead to authorization bypass issues.
Setting a unique cookiePrefix for each middleware instance ensures complete
session isolation, preventing users authenticated via one middleware from
automatically gaining access to routes protected by a different middleware.
The prefix is prepended to all session cookie names:
- Main session cookie: {prefix}m
- Access token cookie: {prefix}a
- Refresh token cookie: {prefix}r
- ID token cookie: {prefix}id
Examples:
- "_oidc_userauth_" - For general user authentication middleware
- "_oidc_adminauth_" - For admin-only authentication middleware
- "_oidc_api_" - For API-specific authentication middleware
Security Note: Use different cookie prefixes AND different sessionEncryptionKey
values for each middleware instance to ensure complete isolation.
Default: "_oidc_raczylo_" (standard prefix for backward compatibility)
See: https://github.com/lukaszraczylo/traefikoidc/issues/87
required: false
sessionMaxAge:
type: integer
description: |
Maximum session age in seconds before requiring re-authentication.
This setting controls how long a user's authentication session remains valid
before they must authenticate again through the OIDC provider. The session
age is tracked from the initial authentication time (created_at).
When a session exceeds this age:
- The session is cleared and invalidated
- The user is redirected to re-authenticate
- All session cookies are removed
Use Cases:
- High-security applications: Use shorter durations (e.g., 3600 = 1 hour)
- Standard applications: Default 24 hours balances security and UX
- Long-lived sessions: Extend for applications accessed infrequently
(e.g., 604800 = 7 days, 2592000 = 30 days)
Security Considerations:
- Shorter sessions provide better security but require more frequent logins
- Longer sessions improve user experience but increase security risk
- Consider your application's security requirements and user access patterns
- This is independent of token refresh - tokens can be refreshed during the session
Common Values:
- 3600 (1 hour) - High security applications
- 28800 (8 hours) - Working day session
- 86400 (24 hours) - Default, balances security and convenience
- 604800 (7 days) - Weekly session for less frequently accessed apps
- 2592000 (30 days) - Monthly session for infrequently used applications
Default: 86400 (24 hours)
Minimum: 0 (uses default of 24 hours)
See: https://github.com/lukaszraczylo/traefikoidc/issues/91
required: false
overrideScopes:
type: boolean
description: |
@@ -588,16 +807,220 @@ configuration:
type: integer
description: |
The number of seconds before a token expires to attempt proactive refresh.
When a request is made and the access token will expire within this grace period,
the middleware will attempt to refresh the token proactively. This helps prevent
authentication interruptions for active users.
Setting this to 0 disables proactive refresh (tokens are only refreshed after expiry).
Default: 60 (1 minute before expiry)
required: false
audience:
type: string
description: |
Custom audience value for access token validation.
This configures the expected audience claim in access tokens. Per OAuth 2.0 and OIDC
specifications:
- ID tokens always have aud=client_id (per OIDC Core 1.0)
- Access tokens can have custom audiences (e.g., API identifiers)
Auth0 Scenarios:
1. Custom API audience (recommended): Set to your API identifier from Auth0
Example: "https://my-api.example.com"
Result: Access tokens will contain this audience
2. Default audience: Leave empty or omit (uses client_id)
Result: Access tokens may not contain client_id, triggering warnings
3. Opaque tokens: Use with allowOpaqueTokens=true for non-JWT tokens
When configured and different from client_id, the middleware automatically adds
the audience parameter to the authorize endpoint request.
Default: "" (uses client_id as audience)
See: AUTH0_AUDIENCE_GUIDE.md for detailed configuration
required: false
strictAudienceValidation:
type: boolean
description: |
Enforce strict audience validation for access tokens.
When enabled, sessions are rejected if access token validation fails due to
audience mismatch. This prevents falling back to ID token validation, addressing
potential token confusion attacks where tokens intended for different APIs could
be used to grant access.
Auth0 Scenario 2 Protection:
- When true: Rejects sessions with mismatched access token audience
- When false: Logs security warnings but allows fallback to ID token (backward compatible)
Security Recommendation:
- Production environments: Set to true for maximum security
- Development/testing: Can use false with monitoring of security warnings
This setting addresses security concerns where access tokens without proper
audience claims could bypass API-specific authorization checks.
Default: false (backward compatible)
See: https://github.com/lukaszraczylo/traefikoidc/issues/74
required: false
allowOpaqueTokens:
type: boolean
description: |
Enable acceptance of opaque (non-JWT) access tokens.
When enabled, the middleware accepts access tokens that are not in JWT format
(3-part base64 structure). Opaque tokens are validated using OAuth 2.0 Token
Introspection (RFC 7662) if the provider exposes an introspection endpoint.
Auth0 Scenario 3:
Some Auth0 configurations issue opaque access tokens when no default API is
configured. This setting allows those tokens to be validated.
Requirements:
- Provider must support introspection_endpoint in OIDC discovery
- Client must have appropriate introspection permissions
Validation Process:
1. Detects opaque token (not 3-part JWT structure)
2. Calls provider's introspection endpoint with client credentials
3. Validates response (active status, expiration, audience if present)
4. Caches result for 5 minutes or token expiry (whichever is shorter)
5. Falls back to ID token validation if introspection unavailable
(unless requireTokenIntrospection=true)
Default: false (only JWT access tokens accepted)
See: AUTH0_AUDIENCE_GUIDE.md for configuration examples
required: false
requireTokenIntrospection:
type: boolean
description: |
Require token introspection for all opaque access tokens.
When enabled with allowOpaqueTokens=true, opaque tokens are rejected if:
- Introspection endpoint is not available from provider metadata
- Introspection request fails
- Introspection response indicates token is not active
Security Levels:
- requireTokenIntrospection=true + allowOpaqueTokens=true:
Maximum security - rejects opaque tokens without successful introspection
- requireTokenIntrospection=false + allowOpaqueTokens=true:
Backward compatible - falls back to ID token validation if introspection fails
- requireTokenIntrospection=true + allowOpaqueTokens=false:
No effect - opaque tokens are already rejected
Recommended Configuration:
When accepting opaque tokens, always set this to true for maximum security:
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
Default: false (allows fallback to ID token)
See: RFC 7662 OAuth 2.0 Token Introspection specification
required: false
disableReplayDetection:
type: boolean
description: |
Disable JTI-based replay attack detection for multi-replica deployments.
When running multiple Traefik replicas, each instance maintains its own in-memory
JTI (JWT Token ID) cache. This causes false positives when the same valid token
hits different replicas:
- Request → Replica A → JTI added to cache → OK
- Request → Replica B → JTI not in Replica B's cache → OK
- Request → Replica A again → JTI found → FALSE POSITIVE "replay detected"
Security Impact:
When disabled, the following validations remain active:
- RSA/ECDSA signature verification
- Token expiration (exp claim)
- Issuer validation (iss claim)
- Audience validation (aud claim)
- Not-before validation (nbf claim)
- Issued-at validation (iat claim)
Only the JTI replay check is skipped.
Recommendations:
- Single-instance deployment: false (default, enables replay protection)
- Multi-replica deployment: true (prevents false positives)
- Production with shared cache: false (use Redis/Memcached for shared JTI cache)
Default: false (replay detection enabled)
required: false
allowPrivateIPAddresses:
type: boolean
description: |
Allow private IP addresses in OIDC provider URLs for internal network deployments.
By default, the plugin blocks URLs containing private IP address ranges
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
OIDC providers are publicly accessible.
Enable this option when:
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
- You don't have DNS resolution available for internal services
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
When enabled, the plugin will accept provider URLs like:
- https://192.168.1.100:8443/auth/realms/your-realm
- https://10.0.0.50:8080/realms/master
- https://172.16.0.10/auth
Security Warning:
Enabling this option reduces SSRF protection. Only use in trusted network
environments where the OIDC provider is known and controlled. Loopback
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
Default: false (private IPs are blocked for security)
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
required: false
minimalHeaders:
type: boolean
description: |
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
When enabled, the middleware only forwards the X-Forwarded-User header and skips
the larger authentication headers that can cause downstream services to reject
requests due to header size limits (typically 8KB).
Headers when disabled (default):
- X-Forwarded-User: User's email address (always set)
- X-Auth-Request-Redirect: Original request URI
- X-Auth-Request-User: User's email address
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
- X-User-Groups: Comma-separated user groups (if configured)
- X-User-Roles: Comma-separated user roles (if configured)
Headers when enabled:
- X-Forwarded-User: User's email address (always set)
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
- Custom templated headers (still processed)
Use this option when:
- Downstream services return "431 Request Header Fields Too Large" errors
- Your ID tokens are large (many claims, long group lists)
- You don't need the full ID token forwarded to backend services
- You want to reduce request overhead
Default: false (all headers forwarded for backward compatibility)
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
headers:
type: array
description: |
@@ -614,29 +1037,23 @@ configuration:
IMPORTANT: Template Escaping
If you encounter the error "can't evaluate field AccessToken in type bool" when
starting Traefik, this means Traefik is trying to evaluate the template expressions
before passing them to the plugin. To fix this, you need to escape the templates
using one of these methods:
before passing them to the plugin.
1. Use YAML literal style (recommended):
headers:
- name: "Authorization"
value: |
Bearer {{.AccessToken}}
SOLUTION: You must escape the template expressions using double curly braces:
2. Use single quotes:
headers:
- name: "Authorization"
value: 'Bearer {{.AccessToken}}'
headers:
- name: "Authorization"
value: "Bearer {{{{.AccessToken}}}}"
3. For inline double quotes, escape the braces:
headers:
- name: "Authorization"
value: "Bearer {{"{{.AccessToken}}"}}"
This is the only reliable method that works consistently. Here's why:
- The YAML parser converts {{{{ → {{ and }}}} → }}
- Result: Bearer {{.AccessToken}} reaches the Go template engine correctly
- Other methods (YAML literal style, single quotes) do NOT work reliably
Examples:
- name: "X-User-Email", value: "{{.Claims.email}}"
- name: "Authorization", value: "Bearer {{.AccessToken}}"
- name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
- name: "X-User-Email", value: "{{{{.Claims.email}}}}"
- name: "Authorization", value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles", value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
required: false
items:
type: object
@@ -899,3 +1316,317 @@ configuration:
Remove the X-Powered-By header to hide technology stack information.
Default: true
required: false
permissionsPolicy:
type: string
description: |
Permissions-Policy header to control browser feature permissions.
This header allows you to control which features and APIs can be used.
Examples:
- "geolocation=(), camera=(), microphone=()" (deny all)
- "geolocation=(self), camera=()" (allow geolocation for same origin only)
Common directives: accelerometer, camera, geolocation, gyroscope,
magnetometer, microphone, payment, usb
required: false
crossOriginEmbedderPolicy:
type: string
description: |
Cross-Origin-Embedder-Policy (COEP) header to prevent untrusted
resources from being loaded.
Options:
- "require-corp": Resources must explicitly grant permission
- "credentialless": Load without credentials for cross-origin resources
- "unsafe-none": No restrictions (default)
Required for certain browser features like SharedArrayBuffer.
required: false
crossOriginOpenerPolicy:
type: string
description: |
Cross-Origin-Opener-Policy (COOP) header to isolate browsing context
from cross-origin windows.
Options:
- "same-origin": Isolate from cross-origin documents
- "same-origin-allow-popups": Allow popups that don't set COOP
- "unsafe-none": No isolation (default)
Helps prevent cross-origin attacks and Spectre-like vulnerabilities.
required: false
crossOriginResourcePolicy:
type: string
description: |
Cross-Origin-Resource-Policy (CORP) header to control which origins
can load this resource.
Options:
- "same-origin": Only same-origin requests can load the resource
- "same-site": Only same-site requests can load the resource
- "cross-origin": Any origin can load the resource (default)
Prevents your resources from being embedded on other sites.
required: false
redis:
type: object
description: |
Optional Redis cache configuration for multi-replica deployments.
When running multiple Traefik instances, Redis provides shared caching to:
- Prevent JTI replay detection false positives across replicas
- Share token verification results between instances
- Maintain consistent session state across the cluster
- Improve performance by reducing redundant OIDC provider calls
Features:
- Automatic failover to memory-only mode when Redis is unavailable
- Circuit breaker pattern for resilience against Redis failures
- Health checking with automatic recovery
- Multiple cache modes: redis-only, hybrid (L1 memory + L2 Redis), memory-only
- Configurable timeouts and connection pooling
- TLS support for secure Redis connections
The middleware gracefully handles Redis failures by falling back to in-memory
caching, ensuring your authentication flow continues even during Redis outages.
Example configuration:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
enableCircuitBreaker: true
```
required: false
properties:
enabled:
type: boolean
description: |
Enable Redis caching for distributed session and token management.
When enabled, the middleware will attempt to connect to Redis and use it
for shared state across multiple Traefik instances.
Default: false
required: false
address:
type: string
description: |
Redis server address in host:port format.
Examples:
- "redis:6379" (Docker/Kubernetes service)
- "localhost:6379" (local Redis)
- "redis.example.com:6380" (custom host/port)
- "redis-cluster.default.svc.cluster.local:6379" (Kubernetes)
Required when Redis is enabled.
required: false
password:
type: string
description: |
Password for Redis authentication.
Leave empty if Redis doesn't require authentication.
For Kubernetes deployments, you can use secret references:
urn:k8s:secret:namespace:secret-name:key
Default: "" (no authentication)
required: false
db:
type: integer
description: |
Redis database number to use (0-15).
Different databases can be used to isolate data between environments.
Default: 0
required: false
keyPrefix:
type: string
description: |
Prefix for all Redis keys created by this middleware.
Useful for:
- Avoiding key collisions with other applications
- Identifying keys for monitoring/debugging
- Supporting multiple environments in the same Redis instance
Default: "traefikoidc:"
required: false
cacheMode:
type: string
description: |
Determines the caching strategy:
- "redis": Redis-only caching. All cache operations go directly to Redis.
Best for: Consistent state across all replicas, minimal memory usage.
- "hybrid": Two-tier caching with in-memory L1 and Redis L2.
Best for: High performance with shared state, reduced Redis load.
L1 provides fast local cache, L2 provides shared state.
- "memory": Memory-only caching (Redis disabled even if configured).
Best for: Single instance deployments, development/testing.
Default: "redis" (when Redis is enabled)
required: false
enum:
- redis
- hybrid
- memory
poolSize:
type: integer
description: |
Maximum number of socket connections to Redis.
Higher values allow more concurrent operations but consume more resources.
Recommendations:
- Small deployments: 10-20
- Medium deployments: 20-50
- Large deployments: 50-100
Default: 10
required: false
connectTimeout:
type: integer
description: |
Timeout in seconds for establishing new connections to Redis.
Should be higher than network latency but low enough to fail fast.
Default: 5 seconds
required: false
readTimeout:
type: integer
description: |
Timeout in seconds for Redis read operations.
Includes the time to send the command, wait for Redis to process it,
and receive the response.
Default: 3 seconds
required: false
writeTimeout:
type: integer
description: |
Timeout in seconds for Redis write operations.
Should account for network latency and Redis persistence settings.
Default: 3 seconds
required: false
enableTLS:
type: boolean
description: |
Enable TLS encryption for Redis connections.
Required when connecting to Redis instances that enforce TLS,
such as AWS ElastiCache with encryption in transit.
Default: false
required: false
tlsSkipVerify:
type: boolean
description: |
Skip TLS certificate verification for Redis connections.
⚠️ WARNING: Only use in development environments.
This option bypasses certificate validation and should never be used
in production as it's vulnerable to man-in-the-middle attacks.
Default: false
required: false
hybridL1Size:
type: integer
description: |
Maximum number of items in the L1 (in-memory) cache for hybrid mode.
Controls how many cache entries are kept in local memory before eviction.
Only applies when cacheMode is "hybrid".
Default: 500
required: false
hybridL1MemoryMB:
type: integer
description: |
Maximum memory in megabytes for L1 cache in hybrid mode.
The cache will start evicting items when this limit is approached.
Only applies when cacheMode is "hybrid".
Default: 10 MB
required: false
enableCircuitBreaker:
type: boolean
description: |
Enable circuit breaker pattern for Redis connection failures.
When enabled, the middleware will:
1. Track Redis operation failures
2. Open the circuit after threshold failures (stop trying Redis)
3. Fall back to in-memory caching
4. Periodically attempt to reconnect (half-open state)
5. Resume Redis operations when connection recovers
This prevents cascading failures and improves resilience.
Default: true
required: false
circuitBreakerThreshold:
type: integer
description: |
Number of consecutive Redis failures before opening the circuit.
Lower values make the system more sensitive to Redis issues,
higher values tolerate more failures before switching to fallback.
Default: 5
required: false
circuitBreakerTimeout:
type: integer
description: |
Time in seconds to wait before attempting to close the circuit.
After this timeout, the circuit breaker will allow one test request
to Redis. If successful, normal operations resume.
Default: 60 seconds
required: false
enableHealthCheck:
type: boolean
description: |
Enable periodic health checks for Redis connection.
Health checks:
- Run in the background at regular intervals
- Detect Redis availability without affecting request processing
- Automatically reconnect when Redis becomes available
- Update circuit breaker state based on health status
Default: true
required: false
healthCheckInterval:
type: integer
description: |
Interval in seconds between Redis health checks.
Lower values detect issues faster but increase Redis load.
Higher values reduce overhead but delay failure detection.
Default: 30 seconds
required: false
+631 -84
View File
@@ -8,6 +8,8 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
- **Domain restrictions**: Limit access to specific email domains or individual users
- **Role-based access control**: Restrict access based on roles and groups from OIDC claims
@@ -75,11 +77,24 @@ experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.2.1 # Use the latest version
version: v0.7.10 # Use the latest version
```
2. Configure the middleware in your dynamic configuration (see examples below).
### Verifying Release Signatures
All release checksums are signed with [cosign](https://github.com/sigstore/cosign) using keyless signing. To verify:
```bash
# Download the checksum file and its sigstore bundle from the release
cosign verify-blob \
--certificate-identity-regexp "https://github.com/lukaszraczylo/traefikoidc/.*" \
--certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
--bundle "traefikoidc_v<version>_checksums.txt.sigstore.json" \
traefikoidc_v<version>_checksums.txt
```
### Local Development with Docker Compose
For local development or testing, you can use the provided Docker Compose setup:
@@ -114,19 +129,46 @@ The middleware supports the following configuration options:
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `forceHTTPS` | Forces HTTPS scheme for redirect URIs (**REQUIRED** for TLS termination at load balancer like AWS ALB) | `false` (when not specified) | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
| `cookieDomain` | Explicit domain for session cookies (important for multi-subdomain setups) | auto-detected | `.example.com`, `app.example.com` |
| `cookiePrefix` | Custom prefix for session cookie names (for isolating multiple middleware instances) | `_oidc_raczylo_` | `_oidc_userauth_`, `_oidc_admin_` |
| `sessionMaxAge` | Maximum session age in seconds before requiring re-authentication | `86400` (24 hours) | `3600` (1 hour), `604800` (7 days) |
| `audience` | Custom audience for access token validation (for Auth0 custom APIs, etc.) | `clientID` | `https://my-api.example.com` |
| `strictAudienceValidation` | Reject sessions with access token audience mismatch (prevents token confusion attacks) | `false` | `true` |
| `allowOpaqueTokens` | Enable opaque (non-JWT) access token support via RFC 7662 introspection | `false` | `true` |
| `requireTokenIntrospection` | Require introspection for opaque tokens (force validation, no fallback) | `false` | `true` |
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
>
> If you're running Traefik behind a load balancer (AWS ALB, Google Cloud Load Balancer, Azure Application Gateway, etc.) that terminates TLS:
> - **You MUST set `forceHTTPS: true`** in your configuration
> - Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures
> - This is especially critical for AWS ALB which may overwrite the `X-Forwarded-Proto` header
>
> **Default behavior:**
> - When `forceHTTPS` is **not specified** in your config → defaults to `false` (Go zero value)
> - When `forceHTTPS: true` is explicitly set → always uses `https://` for redirect URIs
> - When `forceHTTPS: false` is explicitly set → scheme detection based on headers/TLS
>
> See [GitHub Issue #82](https://github.com/lukaszraczylo/traefikoidc/issues/82) for details.
## Scope Configuration
@@ -201,6 +243,103 @@ scopes: []
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
## Auth0 Audience Validation & Security
The middleware provides comprehensive support for Auth0 audience validation to prevent token confusion attacks. Auth0 can issue tokens in three different scenarios, each requiring specific configuration.
### Understanding Token Audiences
Per OAuth 2.0 and OIDC specifications:
- **ID Tokens**: MUST have `aud = client_id` (OIDC Core 1.0 spec)
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
### Auth0 Scenarios
#### Scenario 1: Custom API Audience ✅ (RECOMMENDED)
**Configuration:**
```yaml
audience: "https://my-api.example.com" # Your API identifier from Auth0
strictAudienceValidation: true # Enforce strict validation
```
**Result**: Fully secure, OIDC compliant with proper access token audience validation.
#### Scenario 2: Default Audience ⚠️ (USE WITH CAUTION)
**Configuration:**
```yaml
# audience not specified (defaults to client_id)
strictAudienceValidation: true # Recommended: reject mismatched tokens
```
**Behavior**: Access tokens may not contain client_id in audience, triggering security warnings. Set `strictAudienceValidation: true` to reject such sessions.
#### Scenario 3: Opaque Access Tokens ✅ (SUPPORTED)
**Configuration:**
```yaml
allowOpaqueTokens: true # Enable opaque token support
requireTokenIntrospection: true # Require introspection (recommended)
```
**Result**: Secure with OAuth 2.0 Token Introspection (RFC 7662).
### Security Configuration Options
| Option | Purpose | Recommended Value |
|--------|---------|-------------------|
| `audience` | Expected audience for access tokens | Your API identifier or leave empty |
| `strictAudienceValidation` | Reject sessions with audience mismatch | `true` for production |
| `allowOpaqueTokens` | Accept non-JWT access tokens | `true` if provider issues opaque tokens |
| `requireTokenIntrospection` | Force introspection for opaque tokens | `true` when `allowOpaqueTokens=true` |
### Complete Auth0 Configuration Examples
**Production Configuration (Scenario 1):**
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-secure
spec:
plugin:
traefikoidc:
providerURL: https://your-auth0-domain.auth0.com
clientID: your-auth0-client-id
clientSecret: your-auth0-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
audience: "https://my-api.example.com"
strictAudienceValidation: true
allowedRolesAndGroups:
- "https://your-app.com/roles:admin"
- editor
```
**Opaque Token Configuration (Scenario 3):**
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-opaque
spec:
plugin:
traefikoidc:
providerURL: https://your-auth0-domain.auth0.com
clientID: your-auth0-client-id
clientSecret: your-auth0-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
For detailed Auth0 configuration including all three scenarios, troubleshooting, and security best practices, see **[AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md)**.
## Security Headers Configuration
The middleware includes comprehensive security headers support to protect your applications against common web vulnerabilities. Security headers are applied to all authenticated responses.
@@ -319,6 +458,10 @@ securityHeaders:
| `customHeaders` | Additional custom headers | `{}` | `{"X-Custom": "value"}` |
| `disableServerHeader` | Remove Server header | `true` | `true`, `false` |
| `disablePoweredByHeader` | Remove X-Powered-By header | `true` | `true`, `false` |
| `permissionsPolicy` | Permissions-Policy header | `` | `"geolocation=(), camera=(), microphone=()"` |
| `crossOriginEmbedderPolicy` | Cross-Origin-Embedder-Policy header | `` | `"require-corp"`, `"credentialless"`, `"unsafe-none"` |
| `crossOriginOpenerPolicy` | Cross-Origin-Opener-Policy header | `` | `"same-origin"`, `"same-origin-allow-popups"`, `"unsafe-none"` |
| `crossOriginResourcePolicy` | Cross-Origin-Resource-Policy header | `` | `"same-origin"`, `"same-site"`, `"cross-origin"` |
### CORS Wildcard Support
@@ -390,6 +533,316 @@ securityHeaders:
corsAllowedOrigins: ["http://localhost:*"]
```
### Multi-Replica Deployment Configuration
When running multiple Traefik replicas with the OIDC plugin, you may encounter false positive replay detection errors. Each replica maintains its own in-memory JTI (JWT Token ID) cache, causing legitimate token reuse to be flagged as replay attacks.
**Problem**: When the same valid token hits different replicas:
- Request → Replica A → JTI added to Replica A's cache ✓
- Request → Replica B → JTI NOT in Replica B's cache ✓
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
**Solution 1 (Simple)**: Disable replay detection for distributed deployments:
```yaml
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
```
**Solution 2 (Recommended)**: Use Redis cache backend for shared state (see [Redis Cache](#redis-cache-optional) section)
**Security Note**: When `disableReplayDetection: true`:
- ✅ Token signatures still validated
- ✅ Expiration still checked
- ✅ All other claims still verified
- ❌ JTI replay check **skipped**
**Example Configuration**:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-multi-replica
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
disableReplayDetection: true # Required for multi-replica deployments without Redis
```
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, use the Redis cache backend for proper replay detection across all instances.
## Redis Cache (Optional)
The plugin supports optional Redis caching for multi-replica deployments. This solves issues with JTI replay detection and session management when running multiple Traefik instances behind a load balancer.
> **✨ Yaegi Compatible**: Redis support is implemented using a pure-Go RESP protocol client that works seamlessly with Traefik's Yaegi interpreter (no `unsafe` package). Full Redis functionality is available for both dynamic plugin loading and pre-compiled deployments.
### Why Use Redis Cache?
When running multiple Traefik replicas, each instance maintains its own in-memory cache for:
- JTI (JWT Token ID) replay detection
- Session data
- Token metadata
Without a shared cache, you may experience:
- False positive replay detection errors
- Session inconsistencies between replicas
- Users needing to re-authenticate when hitting different instances
### Basic Configuration
Redis is configured through Traefik's dynamic configuration (YAML, labels, etc.):
```yaml
# Enable Redis cache in your middleware configuration
redis:
enabled: true
address: "localhost:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
```
### Configuration Priority
The plugin uses the following priority for Redis configuration:
1. **Traefik Dynamic Configuration** (PRIMARY) - Configure via YAML files or Docker/Kubernetes labels
2. **Environment Variables** (FALLBACK) - Used only when not set in Traefik config
This approach allows you to manage all settings through Traefik's configuration system while maintaining backward compatibility with environment variables.
### Configuration Options
| Parameter | Description | Default | Example |
|-----------|-------------|---------|---------|
| `enabled` | Enable Redis caching | `false` | `true` |
| `address` | Redis server address | - | `redis:6379` |
| `password` | Redis password | - | `YOUR_PASSWORD` |
| `db` | Database number | `0` | `1` |
| `keyPrefix` | Key prefix for namespacing | `traefikoidc:` | `myapp:` |
| `cacheMode` | Cache mode: `redis`, `hybrid`, `memory` | `redis` | `hybrid` |
| `poolSize` | Connection pool size | `10` | `20` |
| `connectTimeout` | Connection timeout (seconds) | `5` | `10` |
| `readTimeout` | Read timeout (seconds) | `3` | `5` |
| `writeTimeout` | Write timeout (seconds) | `3` | `5` |
| `enableTLS` | Enable TLS | `false` | `true` |
| `tlsSkipVerify` | Skip TLS verification | `false` | `true` |
| `enableCircuitBreaker` | Circuit breaker for failures | `true` | `true` |
| `circuitBreakerThreshold` | Failures before circuit opens | `5` | `10` |
| `circuitBreakerTimeout` | Circuit reset timeout (seconds) | `60` | `30` |
| `enableHealthCheck` | Periodic health checks | `true` | `true` |
| `healthCheckInterval` | Health check interval (seconds) | `30` | `60` |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables can be used as fallback:
- `REDIS_ENABLED` - Enable Redis cache
- `REDIS_ADDRESS` - Redis server address
- `REDIS_PASSWORD` - Redis password
- `REDIS_DB` - Database number
- `REDIS_KEY_PREFIX` - Key prefix
- `REDIS_CACHE_MODE` - Cache mode
- `REDIS_POOL_SIZE` - Connection pool size
- `REDIS_CONNECT_TIMEOUT` - Connection timeout
- `REDIS_READ_TIMEOUT` - Read timeout
- `REDIS_WRITE_TIMEOUT` - Write timeout
- `REDIS_ENABLE_TLS` - Enable TLS
- `REDIS_TLS_SKIP_VERIFY` - Skip TLS verification
### Cache Modes
The plugin supports three cache modes:
- **memory** (default): In-memory cache only, suitable for single-instance deployments
- **redis**: Redis-only cache, all data stored in Redis
- **hybrid**: Two-tier caching with local memory cache + Redis backend for optimal performance
### Example Configurations
#### Docker Compose with Redis
```yaml
services:
redis:
image: redis:alpine
command: redis-server --requirepass yourpassword
traefik:
image: traefik:v3.2
# ... rest of your Traefik configuration
labels:
# Configure the OIDC middleware with Redis
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientSecret=your-secret"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# Redis configuration via labels
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=yourpassword"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
```
#### Kubernetes with Redis
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc"
cacheMode: "hybrid"
```
### Advanced Redis Configuration
See [Redis Cache Documentation](docs/REDIS_CACHE.md) for:
- Detailed architecture overview
- High availability setup with Redis Sentinel
- Redis Cluster configuration
- Performance tuning guidelines
- Monitoring and observability
- Troubleshooting guide
- Migration from memory-only cache
## Dynamic Client Registration (RFC 7591)
The middleware supports **OIDC Dynamic Client Registration** (RFC 7591), allowing automatic client registration with OIDC providers without manual pre-registration. This is useful for:
- **Multi-tenant deployments**: Automatically register clients per tenant
- **Development environments**: Quick setup without manual OAuth app creation
- **Self-service integrations**: Allow applications to self-register
### How It Works
1. When enabled, the middleware discovers the `registration_endpoint` from the provider's `.well-known/openid-configuration`
2. If no `clientID` is configured, it automatically registers a new client with the provider
3. The registered `client_id` and `client_secret` are cached and optionally persisted to a file
4. Subsequent requests use the registered credentials
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-dynamic-registration
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://your-oidc-provider.com
# clientID and clientSecret are NOT required when using DCR
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
dynamicClientRegistration:
enabled: true
# Optional: Initial access token for protected registration endpoints
initialAccessToken: "your-initial-access-token"
# Optional: Override the registration endpoint (auto-discovered by default)
registrationEndpoint: "https://your-provider.com/register"
# Optional: Persist credentials to file for reuse across restarts
persistCredentials: true
credentialsFile: "/tmp/oidc-client-credentials.json"
# Client metadata for registration
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
response_types:
- "code"
token_endpoint_auth_method: "client_secret_basic"
contacts:
- "admin@your-app.com"
```
### DCR Configuration Parameters
| Parameter | Description | Required | Default |
|-----------|-------------|----------|---------|
| `enabled` | Enable dynamic client registration | Yes | `false` |
| `initialAccessToken` | Bearer token for protected registration endpoints | No | - |
| `registrationEndpoint` | Override auto-discovered registration endpoint | No | From discovery |
| `persistCredentials` | Save registered credentials to file | No | `false` |
| `credentialsFile` | Path to store/load credentials | No | `/tmp/oidc-client-credentials.json` |
| `clientMetadata.redirect_uris` | **REQUIRED** - Redirect URIs for OAuth flow | Yes | - |
| `clientMetadata.client_name` | Human-readable client name | No | - |
| `clientMetadata.application_type` | `web` or `native` | No | `web` |
| `clientMetadata.grant_types` | OAuth grant types | No | `["authorization_code", "refresh_token"]` |
| `clientMetadata.response_types` | OAuth response types | No | `["code"]` |
| `clientMetadata.token_endpoint_auth_method` | Authentication method | No | `client_secret_basic` |
| `clientMetadata.contacts` | Contact email addresses | No | - |
| `clientMetadata.logo_uri` | URL to client logo | No | - |
| `clientMetadata.client_uri` | URL to client homepage | No | - |
| `clientMetadata.policy_uri` | URL to privacy policy | No | - |
| `clientMetadata.tos_uri` | URL to terms of service | No | - |
| `clientMetadata.scope` | Space-separated scopes | No | - |
### Provider Support
DCR support varies by provider:
| Provider | DCR Support | Notes |
|----------|-------------|-------|
| Keycloak | ✅ Full | Enable in realm settings |
| Auth0 | ✅ Full | Requires Management API token |
| Okta | ✅ Full | Enable Dynamic Client Registration |
| Azure AD | ⚠️ Limited | App Registration API instead |
| Google | ❌ No | Manual registration required |
| AWS Cognito | ❌ No | Manual registration required |
### Security Considerations
1. **HTTPS Required**: Registration endpoints must use HTTPS (except localhost for development)
2. **Initial Access Token**: Recommended for production to prevent unauthorized registrations
3. **Credential Persistence**: If enabled, ensure the credentials file has appropriate permissions (0600)
4. **Secret Expiration**: Monitor `client_secret_expires_at` and handle rotation if needed
### Example: Keycloak with DCR
```yaml
dynamicClientRegistration:
enabled: true
clientMetadata:
redirect_uris:
- "https://myapp.example.com/oauth2/callback"
client_name: "My App - Production"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
```
## Usage Examples
### Basic Configuration
@@ -568,6 +1021,87 @@ spec:
**Important**: The `cookieDomain` parameter is crucial when running behind a reverse proxy or when your application serves multiple subdomains. Without it, cookies may be created with inconsistent domains, leading to authentication issues like "CSRF token missing in session" errors.
### With Multiple Middleware Instances (Session Isolation)
When running multiple middleware instances with different authorization requirements (e.g., one for general users and one for admins), you must use different `cookiePrefix` values to prevent session sharing between instances:
```yaml
# Middleware for general user authentication
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-userauth
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: user-key-at-least-32-bytes-long
callbackURL: /oauth2/callback
cookiePrefix: "_oidc_userauth_" # Unique prefix for this instance
---
# Middleware for admin authentication with stricter requirements
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-adminauth
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: admin-key-at-least-32-bytes-long # Different encryption key
callbackURL: /oauth2/admin/callback # Different callback URL
cookiePrefix: "_oidc_adminauth_" # Different prefix for isolation
allowedUsers: # Restricted to specific admin users
- admin@example.com
- superadmin@example.com
```
**Security Note**: When running multiple instances, ensure you use:
1. **Different `cookiePrefix`** values to prevent cookie name collisions
2. **Different `sessionEncryptionKey`** values for complete session isolation
3. **Different `callbackURL`** paths to avoid routing conflicts
This configuration prevents authorization bypass issues where a user authenticated via the general middleware could access admin-protected routes. See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87) for more details.
### With Extended Session Duration
For applications that users access infrequently (weekly or monthly), you can extend the session duration beyond the default 24 hours to reduce authentication friction:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-long-session
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-key-at-least-32-bytes-long
callbackURL: /oauth2/callback
sessionMaxAge: 604800 # 7 days (in seconds)
# Other common values:
# 259200 - 3 days
# 604800 - 7 days
# 1209600 - 14 days
# 2592000 - 30 days
```
**Security Note**: Longer session durations improve user experience but increase security risk. Consider your application's security requirements:
- **High-security apps**: Use shorter sessions (3600 = 1 hour)
- **Standard apps**: Default 24 hours balances security and UX
- **Low-frequency access apps**: Extend to 7-30 days for better UX
See [issue #91](https://github.com/lukaszraczylo/traefikoidc/issues/91) for more details.
### With Custom Logging and Rate Limiting
```yaml
@@ -723,6 +1257,45 @@ spec:
- "AppRoleName" # Application role names
```
### Azure AD Configuration (Users Without Email)
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-no-email
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
clientID: your-azure-ad-client-id
clientSecret: your-azure-ad-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
# Use 'sub' instead of 'email' for user identification
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
overrideScopes: true # Optional: Don't request email scope if not needed
scopes:
- openid
- profile
- groups
# When using non-email identifiers, allowedUsers matches against the claim value
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
- "def67890-1234-5678-90ab-cdef12345678"
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
```
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
### Auth0 Configuration
```yaml
@@ -740,14 +1313,26 @@ spec:
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
# Audience configuration for custom APIs
audience: "https://my-api.example.com" # Your API identifier from Auth0
strictAudienceValidation: true # Enforce proper audience validation
scopes:
- read:custom_data # Custom scopes as needed
# Custom claim names for Auth0 namespaced claims
roleClaimName: "https://your-app.com/roles" # Auth0 requires namespaced custom claims
groupClaimName: "https://your-app.com/groups" # Must match claims added in Auth0 Actions
allowedRolesAndGroups:
- "https://your-app.com/roles:admin" # Namespaced claims from Actions
- admin # Will match "admin" in https://your-app.com/roles claim
- editor
postLogoutRedirectURI: /logged-out-page # Must be in Auth0 Allowed Logout URLs
```
**Note**: For detailed Auth0 audience configuration including opaque tokens and all security scenarios, see [AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md).
### Okta Configuration
```yaml
@@ -797,8 +1382,12 @@ spec:
- admin
- editor
# Ensure Keycloak client mappers add necessary claims to ID Token
# For internal Keycloak deployments with private IPs (e.g., Docker network):
# allowPrivateIPAddresses: true
```
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
### AWS Cognito Configuration
```yaml
@@ -920,7 +1509,7 @@ services:
image: traefik:v3.2.1
command:
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.2.1"
- "--experimental.plugins.traefikoidc.version=v0.7.10"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./traefik-config/traefik.yml:/etc/traefik/traefik.yml
@@ -1027,58 +1616,6 @@ http:
{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}
```
## Advanced Configuration
### Session Management
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
### PKCE Support
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
PKCE is recommended when:
- Your OIDC provider supports it (most modern providers do)
- You need an additional layer of security for the authorization code flow
- You're concerned about potential authorization code interception attacks
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
### Session Duration and Token Refresh
This middleware aims to provide long-lived user sessions, typically up to 24 hours, by utilizing OIDC refresh tokens.
**How it works:**
- When a user authenticates, the middleware requests an access token and, if available, a refresh token from the OIDC provider.
- The access token usually has a short lifespan (e.g., 1 hour).
- Before the access token expires (controlled by `refreshGracePeriodSeconds`), the middleware uses the refresh token to obtain a new access token from the provider without requiring the user to log in again.
- This process repeats, allowing the session to remain valid for as long as the refresh token is valid (often 24 hours or more, depending on the provider).
**Provider-Specific Considerations (e.g., Google):**
- Some providers, like Google, issue short-lived access tokens (e.g., 1 hour) and require specific configurations for long-term sessions.
- To enable session extension beyond the initial token expiry with Google and similar providers, the middleware automatically includes the `offline_access` scope in the authentication request. This scope is necessary to obtain a refresh token.
- For Google specifically, the middleware also adds the `prompt=consent` parameter to the initial authorization request. This ensures Google issues a refresh token, which is crucial for extending the session.
- If a refresh attempt fails (e.g., the refresh token is revoked or expired), the user will be required to re-authenticate. The middleware includes enhanced error handling and logging for these scenarios.
- Ensure your OIDC provider is configured to issue refresh tokens and allows their use for extending sessions. Check your provider's documentation for details on refresh token validity periods.
### Google OAuth Compatibility Fix
The middleware includes a specific fix for Google's OAuth implementation, which differs from the standard OIDC specification in how it handles refresh tokens:
- **Issue**: Google does not support the standard `offline_access` scope for requesting refresh tokens and instead requires special parameters.
- **Automatic Solution**: The middleware detects Google as the provider based on the issuer URL and:
- Uses `access_type=offline` query parameter instead of the `offline_access` scope
- Adds `prompt=consent` to ensure refresh tokens are consistently issued
- Properly handles token refresh with Google's implementation
You do not need any special configuration to use Google OAuth - just set `providerURL` to `https://accounts.google.com` and the middleware will automatically apply the proper parameters.
For detailed information on the Google OAuth fix, see the [dedicated documentation](docs/google-oauth-fix.md).
### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
### Templated Headers
The middleware supports setting custom HTTP headers with values templated from OIDC claims and tokens. This allows you to pass authentication information to downstream services in a flexible, customized format.
@@ -1151,12 +1688,39 @@ headers:
When a user is authenticated, the middleware sets the following headers for downstream services:
- `X-Forwarded-User`: The user's email address
- `X-Forwarded-User`: The user's email address (always set)
- `X-User-Groups`: Comma-separated list of user groups (if available)
- `X-User-Roles`: Comma-separated list of user roles (if available)
- `X-Auth-Request-Redirect`: The original request URI
- `X-Auth-Request-User`: The user's email address
- `X-Auth-Request-Token`: The user's access token
- `X-Auth-Request-Token`: The user's ID token (can be large)
#### Minimal Headers Mode
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
```yaml
http:
middlewares:
my-auth:
plugin:
traefikoidc:
minimalHeaders: true
# ... other config
```
When `minimalHeaders: true` is set:
- **Only forwards**: `X-Forwarded-User`
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
- **Still processes**: Custom templated headers
This is particularly useful when:
- Your ID tokens are large (many claims, long group lists)
- Downstream services have limited header buffer sizes (default 8KB in many servers)
- You don't need the full token forwarded to backend services
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
### Security Headers
@@ -1280,32 +1844,6 @@ GitLab supports OIDC for both GitLab.com and self-hosted instances.
* **Scopes**: Use `user:email`, `read:user` for basic profile access
* **Detection**: Auto-detected from `github.com` in issuer URL
### Azure AD (Microsoft Entra ID)
Azure AD generally works well with standard OIDC configurations.
* **ID Token Claims**: Azure AD typically includes standard claims like `email`, `name`, `preferred_username`, and `oid` (Object ID) in the ID Token by default when `openid profile email` scopes are requested.
* **Group Claims**: To include group claims in the ID Token, you need to configure this in the Azure AD application registration:
* Go to your App Registration -> Token configuration -> Add groups claim.
* You can choose which types of groups (Security groups, Directory roles, All groups) to include.
* Be aware of the "overage" issue: If a user is a member of too many groups, Azure AD will send a link to fetch groups instead of embedding them. This plugin currently expects group claims to be directly in the ID token. For users with many groups, consider alternative role/permission management strategies.
* The claim name for groups is typically `groups`.
* **Optional Claims**: You can add other optional claims via the "Token configuration" section of your App Registration. Ensure these are configured for the ID token.
* **Endpoints**: The `providerURL` should be `https://login.microsoftonline.com/{your-tenant-id}/v2.0`. The plugin will auto-discover the necessary endpoints.
* **Optimization**: Ensure your application manifest in Azure AD is configured for the desired token version (v1.0 or v2.0). This plugin works with v2.0 endpoints.
### Google Workspace / Google Cloud Identity
Google's OIDC implementation is well-supported.
* **Optimal Configuration**: The plugin automatically handles Google-specific requirements, such as using `access_type=offline` and `prompt=consent` to ensure refresh tokens are issued for long-lived sessions. You do not need to add `offline_access` to scopes.
* **ID Token Claims**: Google includes standard claims like `email`, `sub`, `name`, `given_name`, `family_name`, `picture` in the ID Token by default with `openid profile email` scopes.
* **Hosted Domain (hd claim)**: If you are using Google Workspace and want to restrict access to users within your organization's domain, Google includes an `hd` (hosted domain) claim in the ID Token. You can use this with the `allowedUserDomains` setting or for custom header logic.
* **Best Practices**:
* Use the `providerURL`: `https://accounts.google.com`.
* Ensure your OAuth consent screen in Google Cloud Console is configured correctly and published. For production, it should be "External" and in "Production" status. "Testing" status limits refresh token lifetime.
* Refer to the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section for more details on how the plugin handles Google's specifics.
### Auth0
Auth0 is generally OIDC compliant and works well.
@@ -1410,6 +1948,15 @@ logLevel: debug
- No refresh tokens (re-authentication required on expiry)
- Use only for GitHub API access, not user authentication
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
- Any name that doesn't contain the literal substring "API"
### Provider Warnings and Recommendations
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
-308
View File
@@ -1,308 +0,0 @@
# Test Execution Guide
This guide explains how to run tests efficiently with the new test categorization and optimization system.
## Quick Start
### Fast Development Testing (Default - Target: < 30 seconds)
```bash
# Run quick smoke tests only
go test ./...
# Or explicitly run in short mode
go test ./... -short
```
### Extended Testing (Target: 2-5 minutes)
```bash
# Enable extended tests with more iterations and concurrency
RUN_EXTENDED_TESTS=1 go test ./...
# Or use the flag equivalent (if using test runner that supports it)
go test ./... -extended
```
### Long-Running Performance Tests (Target: 5-15 minutes)
```bash
# Enable comprehensive performance and stress tests
RUN_LONG_TESTS=1 go test ./...
```
### Full Stress Testing (Target: 10-30 minutes)
```bash
# Enable all stress tests with maximum parameters
RUN_STRESS_TESTS=1 go test ./...
```
## Test Categories
### 1. Quick Tests (Default)
- **Purpose**: Fast feedback during development
- **Duration**: < 30 seconds total
- **Features**:
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Minimal concurrency
- Essential memory leak checks
**Configuration**:
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Cache Size: 50
- Timeout: 10 seconds
### 2. Extended Tests
- **Purpose**: Comprehensive testing before commits
- **Duration**: 2-5 minutes
- **Features**:
- Increased test coverage
- More iterations (5-10)
- Medium concurrency tests
- Enhanced memory leak detection
**Configuration**:
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Cache Size: 200
- Timeout: 30 seconds
### 3. Long Tests
- **Purpose**: Performance validation and stress testing
- **Duration**: 5-15 minutes
- **Features**:
- High iteration counts (50-100)
- High concurrency scenarios
- Large data sets
- Comprehensive memory testing
**Configuration**:
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Cache Size: 1000
- Timeout: 60 seconds
### 4. Stress Tests
- **Purpose**: Maximum load testing and edge case validation
- **Duration**: 10-30 minutes
- **Features**:
- Extreme iteration counts (100-500)
- Maximum concurrency (100+)
- Large memory allocations
- Edge case combinations
**Configuration**:
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Cache Size: 2000
- Timeout: 120 seconds
## Environment Variables
### Test Execution Control
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1 # Enable extended tests
export RUN_LONG_TESTS=1 # Enable long-running tests
export RUN_STRESS_TESTS=1 # Enable stress tests
# Disable specific features
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
```
### Parameter Customization
```bash
# Customize concurrency limits
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
# Customize iteration limits
export TEST_MAX_ITERATIONS=50 # Override max test iterations
# Customize memory thresholds
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
```
## Test-Specific Behavior
### Memory Leak Tests
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
- **Long Mode**: 50-100 iterations, large data sets, performance focus
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
### Concurrency Tests
- **Quick Mode**: 2-5 concurrent operations, basic race detection
- **Extended Mode**: 10-20 concurrent operations, moderate stress
- **Long Mode**: 20-50 concurrent operations, high contention
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
### Cache Tests
- **Quick Mode**: Small caches (50 items), basic operations
- **Extended Mode**: Medium caches (200 items), varied operations
- **Long Mode**: Large caches (1000 items), performance testing
- **Stress Mode**: Very large caches (2000+ items), stress testing
## Integration with CI/CD
### GitHub Actions Example
```yaml
# Quick tests for every push/PR
- name: Quick Tests
run: go test ./... -short
# Extended tests for main branch
- name: Extended Tests
if: github.ref == 'refs/heads/main'
run: RUN_EXTENDED_TESTS=1 go test ./...
# Nightly comprehensive testing
- name: Nightly Stress Tests
if: github.event_name == 'schedule'
run: RUN_STRESS_TESTS=1 go test ./...
```
### Local Development Workflow
```bash
# During active development
go test ./... -short
# Before committing
RUN_EXTENDED_TESTS=1 go test ./...
# Before major releases
RUN_LONG_TESTS=1 go test ./...
# Performance validation
RUN_STRESS_TESTS=1 go test ./...
```
## Performance Optimization Features
### Dynamic Test Scaling
The test system automatically adjusts parameters based on:
- Test mode (quick/extended/long/stress)
- Available resources
- Environment variables
- Previous test performance
### Memory Management
- **Garbage Collection**: Forced GC between test iterations
- **Memory Monitoring**: Real-time memory growth tracking
- **Leak Detection**: Goroutine and memory leak prevention
- **Resource Cleanup**: Automatic cleanup of test resources
### Timeout Management
- **Adaptive Timeouts**: Timeouts scale with test complexity
- **Graceful Degradation**: Tests adapt to slower environments
- **Early Termination**: Failed tests terminate quickly
## Troubleshooting
### Tests Taking Too Long
```bash
# Check if running in extended mode accidentally
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
# Force quick mode
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
go test ./... -short
```
### Memory Issues
```bash
# Reduce memory limits for constrained environments
export TEST_MEMORY_THRESHOLD_MB=5.0
export TEST_MAX_CONCURRENCY=2
go test ./...
```
### Concurrency Issues
```bash
# Reduce concurrency for slower systems
export TEST_MAX_CONCURRENCY=5
export TEST_MAX_ITERATIONS=10
go test ./...
```
### Skip Specific Test Types
```bash
# Skip memory leak detection if problematic
export DISABLE_LEAK_DETECTION=1
go test ./...
```
## Benchmarking
### Running Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
### Benchmark Categories
- **Basic Operations**: Set/Get performance
- **Concurrency**: Multi-threaded performance
- **Memory**: Allocation and cleanup performance
- **Cache**: Eviction and cleanup performance
## Best Practices
### For Developers
1. Always run quick tests during development (`go test ./... -short`)
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
3. Use appropriate test categories for your use case
4. Monitor test execution time and adjust if needed
### For CI/CD
1. Use quick tests for fast feedback on PRs
2. Use extended tests for main branch validation
3. Use long tests for release validation
4. Use stress tests for nightly/weekly validation
### For Performance Testing
1. Use consistent environment variables
2. Run tests multiple times for statistical significance
3. Monitor both execution time and resource usage
4. Use profiling tools for detailed analysis
## Examples
### Daily Development
```bash
# Fast tests while coding
go test ./... -short
# Before git commit
RUN_EXTENDED_TESTS=1 go test ./...
```
### Release Testing
```bash
# Comprehensive validation
RUN_LONG_TESTS=1 go test ./...
# Stress testing
RUN_STRESS_TESTS=1 go test ./...
```
### Custom Configuration
```bash
# Custom limits for specific environment
export TEST_MAX_CONCURRENCY=8
export TEST_MAX_ITERATIONS=25
export TEST_MEMORY_THRESHOLD_MB=15.0
RUN_EXTENDED_TESTS=1 go test ./...
```
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
+1518
View File
File diff suppressed because it is too large Load Diff
-360
View File
@@ -1,360 +0,0 @@
// Package auth provides authentication-related functionality for the OIDC middleware.
package auth
import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/google/uuid"
)
// AuthHandler provides core authentication functionality for OIDC flows
type AuthHandler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// NewAuthHandler creates a new AuthHandler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
return &AuthHandler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
}
}
// InitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
session SessionData, redirectURL string,
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return
}
session.IncrementRedirectCount()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
h.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if h.enablePKCE {
codeVerifier, err = generateCodeVerifier()
if err != nil {
h.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
codeChallenge, err = deriveCodeChallenge()
if err != nil {
h.logger.Errorf("Failed to generate code challenge: %v", err)
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
return
}
h.logger.Debugf("PKCE enabled, generated code challenge")
}
session.SetAuthenticated(false)
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if h.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// BuildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", h.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
if h.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
scopes := make([]string, len(h.scopes))
copy(scopes, h.scopes)
if h.isGoogleProv() {
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
params.Set("prompt", "consent")
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else if h.isAzureProv() {
params.Set("response_mode", "query")
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
} else {
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
return h.buildURLWithParams(h.authURL, params)
}
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
// It handles both relative and absolute URLs, validates URL security,
// and properly encodes query parameters.
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
if baseURL != "" {
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
if err := h.validateURL(baseURL); err != nil {
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
return ""
}
}
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
issuerURLParsed, err := url.Parse(h.issuerURL)
if err != nil {
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
return ""
}
baseURLParsed, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
return ""
}
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
if err := h.validateURL(resolvedURL.String()); err != nil {
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
return ""
}
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
u, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
return ""
}
if err := h.validateParsedURL(u); err != nil {
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
return ""
}
u.RawQuery = params.Encode()
return u.String()
}
// validateURL performs security validation on URLs to prevent SSRF attacks.
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
func (h *AuthHandler) validateURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("empty URL")
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
return h.validateParsedURL(u)
}
// validateParsedURL validates a parsed URL structure for security.
// It checks schemes, hosts, and paths to prevent malicious URLs.
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
allowedSchemes := map[string]bool{
"https": true,
"http": true,
}
if !allowedSchemes[u.Scheme] {
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
}
if u.Scheme == "http" {
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if err := h.validateHost(u.Host); err != nil {
return fmt.Errorf("invalid host: %w", err)
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// validateHost validates a hostname for security and reachability.
// It prevents access to private networks and localhost addresses.
func (h *AuthHandler) validateHost(host string) error {
if host == "" {
return fmt.Errorf("empty host")
}
// Strip port if present
if strings.Contains(host, ":") {
var err error
host, _, err = net.SplitHostPort(host)
if err != nil {
return fmt.Errorf("invalid host:port format: %w", err)
}
}
// Check for localhost variations
localhostVariations := []string{
"localhost", "127.0.0.1", "::1", "0.0.0.0",
}
for _, localhost := range localhostVariations {
if strings.EqualFold(host, localhost) {
return fmt.Errorf("localhost access not allowed: %s", host)
}
}
// Try to parse as IP address
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return fmt.Errorf("loopback IP not allowed: %s", host)
}
if ip.IsPrivate() {
return fmt.Errorf("private IP not allowed: %s", host)
}
if ip.IsLinkLocalUnicast() {
return fmt.Errorf("link-local IP not allowed: %s", host)
}
if ip.IsMulticast() {
return fmt.Errorf("multicast IP not allowed: %s", host)
}
}
return nil
}
// SessionData interface for dependency injection
type SessionData interface {
GetRedirectCount() int
ResetRedirectCount()
IncrementRedirectCount()
SetAuthenticated(bool)
SetEmail(string)
SetAccessToken(string)
SetRefreshToken(string)
SetIDToken(string)
SetNonce(string)
SetCodeVerifier(string)
SetCSRF(string)
SetIncomingPath(string)
MarkDirty()
Save(req *http.Request, rw http.ResponseWriter) error
}
-599
View File
@@ -1,599 +0,0 @@
package auth
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// Test mocks
type mockLogger struct {
debugMessages []string
errorMessages []string
}
func (l *mockLogger) Debugf(format string, args ...interface{}) {
l.debugMessages = append(l.debugMessages, format)
}
func (l *mockLogger) Errorf(format string, args ...interface{}) {
l.errorMessages = append(l.errorMessages, format)
}
type mockSessionData struct {
authenticated bool
email string
accessToken string
refreshToken string
idToken string
csrf string
nonce string
codeVerifier string
incomingPath string
redirectCount int
saveError error
dirty bool
}
func (s *mockSessionData) GetRedirectCount() int { return s.redirectCount }
func (s *mockSessionData) ResetRedirectCount() { s.redirectCount = 0 }
func (s *mockSessionData) IncrementRedirectCount() { s.redirectCount++ }
func (s *mockSessionData) SetAuthenticated(auth bool) { s.authenticated = auth }
func (s *mockSessionData) SetEmail(email string) { s.email = email }
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
func (s *mockSessionData) SetCodeVerifier(verifier string) { s.codeVerifier = verifier }
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
func (s *mockSessionData) MarkDirty() { s.dirty = true }
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
return s.saveError
}
// TestAuthHandler_NewAuthHandler tests the constructor
func TestAuthHandler_NewAuthHandler(t *testing.T) {
logger := &mockLogger{}
isGoogleProv := func() bool { return false }
isAzureProv := func() bool { return true }
scopes := []string{"openid", "profile", "email"}
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
"test-client-id", "https://example.com/auth", "https://example.com",
scopes, false)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.logger != logger {
t.Error("Logger not set correctly")
}
if !handler.enablePKCE {
t.Error("PKCE should be enabled")
}
if handler.clientID != "test-client-id" {
t.Errorf("Expected clientID 'test-client-id', got '%s'", handler.clientID)
}
if handler.authURL != "https://example.com/auth" {
t.Errorf("Expected authURL 'https://example.com/auth', got '%s'", handler.authURL)
}
if handler.issuerURL != "https://example.com" {
t.Errorf("Expected issuerURL 'https://example.com', got '%s'", handler.issuerURL)
}
if len(handler.scopes) != 3 {
t.Errorf("Expected 3 scopes, got %d", len(handler.scopes))
}
if handler.overrideScopes {
t.Error("overrideScopes should be false")
}
}
// TestAuthHandler_InitiateAuthentication_MaxRedirects tests redirect limit enforcement
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
session := &mockSessionData{redirectCount: 5} // At the limit
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "test-nonce", nil }
generateCodeVerifier := func() (string, error) { return "", nil }
deriveCodeChallenge := func() (string, error) { return "", nil }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
if rw.Code != http.StatusLoopDetected {
t.Errorf("Expected status %d, got %d", http.StatusLoopDetected, rw.Code)
}
body := rw.Body.String()
if !strings.Contains(body, "Too many redirects") {
t.Errorf("Expected 'Too many redirects' in response body, got '%s'", body)
}
if session.redirectCount != 0 {
t.Errorf("Expected redirect count to be reset, got %d", session.redirectCount)
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestAuthHandler_InitiateAuthentication_NonceGenerationError tests nonce generation failure
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "", &testError{"nonce generation failed"} }
generateCodeVerifier := func() (string, error) { return "", nil }
deriveCodeChallenge := func() (string, error) { return "", nil }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
if rw.Code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
}
body := rw.Body.String()
if !strings.Contains(body, "Failed to generate nonce") {
t.Errorf("Expected 'Failed to generate nonce' in response body, got '%s'", body)
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError tests PKCE code verifier generation failure
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "test-nonce", nil }
generateCodeVerifier := func() (string, error) { return "", &testError{"code verifier generation failed"} }
deriveCodeChallenge := func() (string, error) { return "", nil }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
if rw.Code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
}
body := rw.Body.String()
if !strings.Contains(body, "Failed to generate code verifier") {
t.Errorf("Expected 'Failed to generate code verifier' in response body, got '%s'", body)
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError tests PKCE code challenge derivation failure
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "test-nonce", nil }
generateCodeVerifier := func() (string, error) { return "test-verifier", nil }
deriveCodeChallenge := func() (string, error) { return "", &testError{"code challenge derivation failed"} }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
if rw.Code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
}
body := rw.Body.String()
if !strings.Contains(body, "Failed to generate code challenge") {
t.Errorf("Expected 'Failed to generate code challenge' in response body, got '%s'", body)
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestAuthHandler_InitiateAuthentication_SessionSaveError tests session save failure
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
session := &mockSessionData{saveError: &testError{"save failed"}}
req := httptest.NewRequest("GET", "/test?param=value", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "test-nonce", nil }
generateCodeVerifier := func() (string, error) { return "", nil }
deriveCodeChallenge := func() (string, error) { return "", nil }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
if rw.Code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
}
body := rw.Body.String()
if !strings.Contains(body, "Failed to save session") {
t.Errorf("Expected 'Failed to save session' in response body, got '%s'", body)
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
// Verify session was prepared correctly before the save failure
if session.incomingPath != "/test?param=value" {
t.Errorf("Expected incoming path '/test?param=value', got '%s'", session.incomingPath)
}
if session.nonce != "test-nonce" {
t.Errorf("Expected nonce 'test-nonce', got '%s'", session.nonce)
}
if session.redirectCount != 1 {
t.Errorf("Expected redirect count 1, got %d", session.redirectCount)
}
}
// TestAuthHandler_InitiateAuthentication_Success tests successful authentication initiation
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/protected/resource", nil)
rw := httptest.NewRecorder()
generateNonce := func() (string, error) { return "generated-nonce", nil }
generateCodeVerifier := func() (string, error) { return "generated-verifier", nil }
deriveCodeChallenge := func() (string, error) { return "generated-challenge", nil }
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
generateNonce, generateCodeVerifier, deriveCodeChallenge)
// Should redirect
if rw.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
}
location := rw.Header().Get("Location")
if location == "" {
t.Error("Expected Location header to be set")
}
// Parse the redirect URL to verify parameters
parsedURL, err := url.Parse(location)
if err != nil {
t.Fatalf("Failed to parse redirect URL: %v", err)
}
query := parsedURL.Query()
// Verify required parameters
if query.Get("client_id") != "test-client" {
t.Errorf("Expected client_id 'test-client', got '%s'", query.Get("client_id"))
}
if query.Get("response_type") != "code" {
t.Errorf("Expected response_type 'code', got '%s'", query.Get("response_type"))
}
if query.Get("redirect_uri") != "https://example.com/callback" {
t.Errorf("Expected redirect_uri 'https://example.com/callback', got '%s'", query.Get("redirect_uri"))
}
if query.Get("nonce") != "generated-nonce" {
t.Errorf("Expected nonce 'generated-nonce', got '%s'", query.Get("nonce"))
}
// Verify PKCE parameters
if query.Get("code_challenge") != "generated-challenge" {
t.Errorf("Expected code_challenge 'generated-challenge', got '%s'", query.Get("code_challenge"))
}
if query.Get("code_challenge_method") != "S256" {
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
}
// Verify scope
scope := query.Get("scope")
if !strings.Contains(scope, "openid") || !strings.Contains(scope, "email") {
t.Errorf("Expected scope to contain 'openid' and 'email', got '%s'", scope)
}
// Verify session was updated correctly
if !session.dirty {
t.Error("Expected session to be marked dirty")
}
if session.incomingPath != "/protected/resource" {
t.Errorf("Expected incoming path '/protected/resource', got '%s'", session.incomingPath)
}
if session.nonce != "generated-nonce" {
t.Errorf("Expected session nonce 'generated-nonce', got '%s'", session.nonce)
}
if session.codeVerifier != "generated-verifier" {
t.Errorf("Expected session code verifier 'generated-verifier', got '%s'", session.codeVerifier)
}
// Verify session data was cleared
if session.authenticated {
t.Error("Expected session to not be authenticated")
}
if session.email != "" {
t.Errorf("Expected email to be cleared, got '%s'", session.email)
}
if session.accessToken != "" {
t.Errorf("Expected access token to be cleared, got '%s'", session.accessToken)
}
if session.idToken != "" {
t.Errorf("Expected ID token to be cleared, got '%s'", session.idToken)
}
}
// TestAuthHandler_BuildAuthURL_GoogleProvider tests Google-specific URL building
func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
[]string{"openid", "profile", "email"}, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// Google-specific parameters
if query.Get("access_type") != "offline" {
t.Errorf("Expected access_type 'offline' for Google, got '%s'", query.Get("access_type"))
}
if query.Get("prompt") != "consent" {
t.Errorf("Expected prompt 'consent' for Google, got '%s'", query.Get("prompt"))
}
// Standard parameters should still be present
if query.Get("client_id") != "google-client" {
t.Errorf("Expected client_id 'google-client', got '%s'", query.Get("client_id"))
}
if query.Get("state") != "test-state" {
t.Errorf("Expected state 'test-state', got '%s'", query.Get("state"))
}
if query.Get("nonce") != "test-nonce" {
t.Errorf("Expected nonce 'test-nonce', got '%s'", query.Get("nonce"))
}
}
// TestAuthHandler_BuildAuthURL_AzureProvider tests Azure-specific URL building
func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/tenant/v2.0",
[]string{"openid", "profile", "email"}, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// Azure-specific parameters
if query.Get("response_mode") != "query" {
t.Errorf("Expected response_mode 'query' for Azure, got '%s'", query.Get("response_mode"))
}
// Azure should add offline_access scope automatically
scope := query.Get("scope")
if !strings.Contains(scope, "offline_access") {
t.Errorf("Expected scope to contain 'offline_access' for Azure, got '%s'", scope)
}
}
// TestAuthHandler_BuildAuthURL_PKCEEnabled tests PKCE parameter inclusion
func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
if query.Get("code_challenge") != "test-challenge" {
t.Errorf("Expected code_challenge 'test-challenge', got '%s'", query.Get("code_challenge"))
}
if query.Get("code_challenge_method") != "S256" {
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
}
}
// TestAuthHandler_BuildAuthURL_PKCEDisabled tests when PKCE is disabled
func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"no-pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// PKCE parameters should not be included
if query.Get("code_challenge") != "" {
t.Errorf("Expected no code_challenge when PKCE disabled, got '%s'", query.Get("code_challenge"))
}
if query.Get("code_challenge_method") != "" {
t.Errorf("Expected no code_challenge_method when PKCE disabled, got '%s'", query.Get("code_challenge_method"))
}
}
// TestAuthHandler_BuildAuthURL_ScopeHandling tests various scope configurations
func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
tests := []struct {
name string
scopes []string
overrideScopes bool
isAzure bool
expectedScopes []string
}{
{
name: "Basic scopes",
scopes: []string{"openid", "profile", "email"},
overrideScopes: false,
isAzure: false,
expectedScopes: []string{"openid", "profile", "email", "offline_access"},
},
{
name: "Azure with offline_access already present",
scopes: []string{"openid", "profile", "offline_access"},
overrideScopes: false,
isAzure: true,
expectedScopes: []string{"openid", "profile", "offline_access"},
},
{
name: "Azure auto-add offline_access",
scopes: []string{"openid", "profile"},
overrideScopes: false,
isAzure: true,
expectedScopes: []string{"openid", "profile", "offline_access"},
},
{
name: "Override scopes with empty array",
scopes: []string{},
overrideScopes: true,
isAzure: true,
expectedScopes: []string{"offline_access"},
},
{
name: "Override scopes prevents auto-add",
scopes: []string{"openid", "custom_scope"},
overrideScopes: true,
isAzure: true,
expectedScopes: []string{"openid", "custom_scope"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
"test-client", "https://example.com/auth", "https://example.com",
tt.scopes, tt.overrideScopes)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
actualScopes := strings.Split(actualScope, " ")
// Check each expected scope is present
for _, expectedScope := range tt.expectedScopes {
found := false
for _, actualScope := range actualScopes {
if actualScope == expectedScope {
found = true
break
}
}
if !found {
t.Errorf("Expected scope '%s' not found in '%s'", expectedScope, actualScope)
}
}
// Check no unexpected scopes are present
for _, actualScope := range actualScopes {
if actualScope == "" {
continue // Skip empty strings from split
}
found := false
for _, expectedScope := range tt.expectedScopes {
if actualScope == expectedScope {
found = true
break
}
}
if !found {
t.Errorf("Unexpected scope '%s' found in '%s'", actualScope, parsedURL.Query().Get("scope"))
}
}
})
}
}
// Test helper type for errors
type testError struct {
message string
}
func (e *testError) Error() string {
return e.message
}
-562
View File
@@ -1,562 +0,0 @@
package auth
import (
"net/url"
"strings"
"testing"
)
// TestAuthHandler_validateURL tests URL validation functionality
func TestAuthHandler_validateURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
tests := []struct {
name string
url string
wantErr bool
errMsg string
}{
{
name: "Valid HTTPS URL",
url: "https://example.com/auth",
wantErr: false,
},
{
name: "Valid HTTP URL",
url: "http://example.com/auth",
wantErr: false,
},
{
name: "Empty URL",
url: "",
wantErr: true,
errMsg: "empty URL",
},
{
name: "Invalid URL format",
url: "not-a-url",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - javascript",
url: "javascript:alert('xss')",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - data",
url: "data:text/html,<script>alert('xss')</script>",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - file",
url: "file:///etc/passwd",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Disallowed scheme - ftp",
url: "ftp://example.com/file",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Missing host",
url: "https:///path",
wantErr: true,
errMsg: "missing host",
},
{
name: "Path traversal attempt",
url: "https://example.com/../../../etc/passwd",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Path traversal in middle",
url: "https://example.com/path/../sensitive/file",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Localhost attempt",
url: "https://localhost/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1 attempt",
url: "https://127.0.0.1/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "IPv6 localhost attempt",
url: "https://[::1]/auth",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "0.0.0.0 attempt",
url: "https://0.0.0.0/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Private IP - 192.168.x.x",
url: "https://192.168.1.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP - 10.x.x.x",
url: "https://10.0.0.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP - 172.16.x.x",
url: "https://172.16.0.1/auth",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Link-local IP",
url: "https://169.254.1.1/auth",
wantErr: true,
errMsg: "link-local IP not allowed",
},
{
name: "Multicast IP",
url: "https://224.0.0.1/auth",
wantErr: true,
errMsg: "multicast IP not allowed",
},
{
name: "Valid public IP",
url: "https://8.8.8.8/auth",
wantErr: false,
},
{
name: "Valid domain with port",
url: "https://example.com:8443/auth",
wantErr: false,
},
{
name: "localhost with case variation",
url: "https://LOCALHOST/auth",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Invalid host:port format",
url: "https://example.com:notanumber/auth",
wantErr: true,
errMsg: "invalid URL format",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := handler.validateURL(tt.url)
if tt.wantErr {
if err == nil {
t.Errorf("validateURL() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateURL() unexpected error = %v", err)
}
}
})
}
}
// TestAuthHandler_validateHost tests host validation specifically
func TestAuthHandler_validateHost(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
tests := []struct {
name string
host string
wantErr bool
errMsg string
}{
{
name: "Valid hostname",
host: "example.com",
wantErr: false,
},
{
name: "Valid hostname with subdomain",
host: "api.example.com",
wantErr: false,
},
{
name: "Valid hostname with port",
host: "example.com:8080",
wantErr: false,
},
{
name: "Empty host",
host: "",
wantErr: true,
errMsg: "empty host",
},
{
name: "localhost",
host: "localhost",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "LOCALHOST (case insensitive)",
host: "LOCALHOST",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "localhost with port",
host: "localhost:8080",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1",
host: "127.0.0.1",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "127.0.0.1 with port",
host: "127.0.0.1:8080",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "IPv6 localhost",
host: "::1",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "0.0.0.0",
host: "0.0.0.0",
wantErr: true,
errMsg: "localhost access not allowed",
},
{
name: "Private IP 192.168.1.1",
host: "192.168.1.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP 10.0.0.1",
host: "10.0.0.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Private IP 172.16.0.1",
host: "172.16.0.1",
wantErr: true,
errMsg: "private IP not allowed",
},
{
name: "Public IP 8.8.8.8",
host: "8.8.8.8",
wantErr: false,
},
{
name: "Link-local IP",
host: "169.254.1.1",
wantErr: true,
errMsg: "link-local IP not allowed",
},
{
name: "Multicast IP",
host: "224.0.0.1",
wantErr: true,
errMsg: "multicast IP not allowed",
},
{
name: "Invalid host:port format",
host: "example.com::",
wantErr: true,
errMsg: "invalid host:port format",
},
{
name: "Valid international domain",
host: "example.org",
wantErr: false,
},
{
name: "Valid ccTLD",
host: "example.co.uk",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := handler.validateHost(tt.host)
if tt.wantErr {
if err == nil {
t.Errorf("validateHost() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateHost() unexpected error = %v", err)
}
}
})
}
}
// TestAuthHandler_buildURLWithParams tests URL building with parameters
func TestAuthHandler_buildURLWithParams(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
tests := []struct {
name string
baseURL string
params url.Values
expected string
expectEmpty bool
}{
{
name: "Absolute HTTPS URL",
baseURL: "https://provider.com/auth",
params: url.Values{
"client_id": []string{"test-client"},
"response_type": []string{"code"},
},
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
},
{
name: "Absolute HTTP URL",
baseURL: "http://provider.com/auth",
params: url.Values{
"state": []string{"test-state"},
},
expected: "http://provider.com/auth?state=test-state",
},
{
name: "Relative URL resolved against issuer",
baseURL: "/oauth2/authorize",
params: url.Values{
"scope": []string{"openid"},
},
expected: "https://example.com/oauth2/authorize?scope=openid",
},
{
name: "Root relative URL",
baseURL: "/auth",
params: url.Values{
"nonce": []string{"test-nonce"},
},
expected: "https://example.com/auth?nonce=test-nonce",
},
{
name: "Invalid absolute URL",
baseURL: "https://localhost/auth",
params: url.Values{},
expectEmpty: true, // Should return empty string due to validation failure
},
{
name: "Invalid relative URL when resolved",
baseURL: "/auth",
params: url.Values{},
expected: "", // Should be empty because issuer validation would be tested separately
expectEmpty: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildURLWithParams(tt.baseURL, tt.params)
if tt.expectEmpty {
if result != "" {
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
}
return
}
// For relative URLs, we expect them to be resolved against the issuer URL
if !strings.HasPrefix(tt.baseURL, "http") {
// Verify it starts with the issuer URL
if !strings.HasPrefix(result, handler.issuerURL) {
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
}
}
// Parse the result to verify parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
}
// Verify all expected parameters are present
resultParams := parsedURL.Query()
for key, expectedValues := range tt.params {
actualValues := resultParams[key]
if len(actualValues) != len(expectedValues) {
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
continue
}
for i, expectedValue := range expectedValues {
if actualValues[i] != expectedValue {
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
}
}
}
})
}
}
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
// Test special characters that need encoding
params := url.Values{
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
"state": []string{"state with spaces and & special chars"},
"scope": []string{"openid profile email"},
"special": []string{"value+with+plus&ampersand=equals"},
}
result := handler.buildURLWithParams("https://provider.com/auth", params)
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse result URL: %v", err)
}
// Verify parameters are correctly encoded/decoded
resultParams := parsedURL.Query()
expectedParams := map[string]string{
"redirect_uri": "https://example.com/callback?test=value&other=data",
"state": "state with spaces and & special chars",
"scope": "openid profile email",
"special": "value+with+plus&ampersand=equals",
}
for key, expectedValue := range expectedParams {
actualValue := resultParams.Get(key)
if actualValue != expectedValue {
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
}
}
}
// TestAuthHandler_validateParsedURL tests validateParsedURL method
func TestAuthHandler_validateParsedURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
tests := []struct {
name string
url string
wantErr bool
errMsg string
}{
{
name: "Valid HTTPS URL",
url: "https://example.com/path",
wantErr: false,
},
{
name: "Valid HTTP URL with warning",
url: "http://example.com/path",
wantErr: false, // Should not error but should log warning
},
{
name: "Invalid scheme",
url: "javascript:alert('xss')",
wantErr: true,
errMsg: "disallowed URL scheme",
},
{
name: "Missing host",
url: "https:///path",
wantErr: true,
errMsg: "missing host",
},
{
name: "Path traversal",
url: "https://example.com/path/../../../etc",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "Invalid host (private IP)",
url: "https://192.168.1.1/path",
wantErr: true,
errMsg: "invalid host",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsedURL, err := url.Parse(tt.url)
if err != nil {
t.Fatalf("Failed to parse test URL: %v", err)
}
err = handler.validateParsedURL(parsedURL)
if tt.wantErr {
if err == nil {
t.Errorf("validateParsedURL() expected error but got none")
return
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
}
} else {
if err != nil {
t.Errorf("validateParsedURL() unexpected error = %v", err)
}
// Check for HTTP warning in debug logs
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
found := false
for _, msg := range logger.debugMessages {
if strings.Contains(msg, "Warning: Using HTTP scheme") {
found = true
break
}
}
if !found {
t.Error("Expected HTTP scheme warning in debug logs")
}
}
}
})
}
}
+21 -15
View File
@@ -8,10 +8,6 @@ import (
"github.com/google/uuid"
)
// ============================================================================
// AUTHENTICATION FLOW
// ============================================================================
// validateRedirectCount checks if redirect limit is exceeded and handles the error
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
const maxRedirects = 5
@@ -47,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
// prepareSessionForAuthentication clears existing session data and sets new authentication state
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
session.SetAuthenticated(false)
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
@@ -223,15 +219,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("Email claim missing or empty in token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
}
if !t.isAllowedDomain(email) {
t.logger.Errorf("Disallowed email domain during callback: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
// Validate user authorization
if !t.isAllowedUser(userIdentifier) {
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
@@ -240,7 +246,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
@@ -276,7 +282,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
session.SetAuthenticated(false)
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
File diff suppressed because it is too large Load Diff
+101
View File
@@ -0,0 +1,101 @@
package traefikoidc
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestGeneratePKCEParameters tests the generatePKCEParameters method
func TestGeneratePKCEParameters(t *testing.T) {
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE enabled
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
// Verify the challenge is derived from the verifier
expectedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
})
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE disabled
plugin := &TraefikOidc{
enablePKCE: false,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
})
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
require.NoError(t, err1)
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
require.NoError(t, err2)
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
})
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// The challenge should always be derivable from the verifier
recalculatedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, challenge, recalculatedChallenge,
"challenge should always match the SHA256 hash of verifier")
})
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, _, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// RFC 7636 requires verifier to be 43-128 characters
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
})
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
_, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// SHA256 hash base64 encoded should be 43 characters
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
})
}
+19 -16
View File
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
// It limits concurrent task execution and tracks failures to prevent system overload
type TaskCircuitBreaker struct {
state int32 // CircuitBreakerState
failureCount int32
lastFailureTime int64 // Unix timestamp
failureThreshold int32
timeout time.Duration
logger *Logger
// Concurrency limiting
concurrentTasks int32 // Current number of running tasks
maxConcurrent int32 // Maximum concurrent tasks allowed
activeTasks map[string]struct{} // Track active task names
tasksMu sync.RWMutex // Separate mutex for task tracking
activeTasks map[string]struct{}
lastFailureTime int64
timeout time.Duration
tasksMu sync.RWMutex
state int32
failureCount int32
failureThreshold int32
concurrentTasks int32
maxConcurrent int32
}
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
max := atomic.LoadInt32(&cb.maxConcurrent)
// For cleanup tasks, be more restrictive (singleton-like behavior)
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
cb.tasksMu.RLock()
hasCleanupTask := false
hasSameTask := false
for activeTask := range cb.activeTasks {
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
hasCleanupTask = true
// Only block if the EXACT same task is already running
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
if activeTask == taskName {
hasSameTask = true
break
}
}
cb.tasksMu.RUnlock()
if hasCleanupTask {
if hasSameTask {
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
}
}
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
cb *TaskCircuitBreaker
logger *Logger
mu sync.RWMutex
}
// GlobalTaskRegistry is the singleton instance for managing all background tasks
@@ -538,7 +540,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
// Start the task if not already running
if !rm.IsTaskRunning(name) {
rm.StartBackgroundTask(name)
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
}
// Get the task from resource manager's internal registry
@@ -787,6 +789,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
}
if mm.logger != nil {
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
}
+8 -7
View File
@@ -58,6 +58,7 @@ func TestAzureOIDCRegression(t *testing.T) {
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60 * time.Second,
@@ -78,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
// Initialize session manager
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
tOidc.sessionManager = sessionManager
// Mock the JWT verification to avoid JWKS lookup issues
@@ -329,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "ValidGoogleTokens",
@@ -475,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
setupSession func() *SessionData
name string
providerType string
setupSession func() *SessionData
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "AzureProvider",
@@ -659,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "UnauthenticatedWithRefreshToken",
+536
View File
@@ -0,0 +1,536 @@
package traefikoidc
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestMemoryMonitorComprehensive tests memory monitor edge cases
func TestMemoryMonitorComprehensive(t *testing.T) {
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Should not panic
assert.NotPanics(t, func() {
monitor.TriggerGC()
})
})
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Initially should return None (no stats yet)
pressure := monitor.GetMemoryPressure()
assert.Equal(t, MemoryPressureNone, pressure)
// Collect stats to populate lastStats
monitor.GetCurrentStats()
// Now should return a valid pressure level
pressure = monitor.GetMemoryPressure()
assert.NotNil(t, pressure)
})
t.Run("StartMonitoring can be called", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Start monitoring should not panic
assert.NotPanics(t, func() {
ctx := context.Background()
monitor.StartMonitoring(ctx, 100*time.Millisecond)
time.Sleep(GetTestDuration(50 * time.Millisecond))
})
// Clean up
monitor.StopMonitoring()
})
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// StopMonitoring should not panic even if not started
assert.NotPanics(t, func() {
monitor.StopMonitoring()
})
// Can be called multiple times safely
assert.NotPanics(t, func() {
monitor.StopMonitoring()
monitor.StopMonitoring()
})
})
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
// Get initial instance
GetGlobalMemoryMonitor()
// Reset
ResetGlobalMemoryMonitor()
// Should be able to get a new instance
monitor := GetGlobalMemoryMonitor()
assert.NotNil(t, monitor)
// Clean up
monitor.StopMonitoring()
ResetGlobalMemoryMonitor()
})
t.Run("String method returns pressure name", func(t *testing.T) {
pressures := []struct {
name string
level MemoryPressureLevel
}{
{level: MemoryPressureNone, name: "None"},
{level: MemoryPressureLow, name: "Low"},
{level: MemoryPressureModerate, name: "Moderate"},
{level: MemoryPressureHigh, name: "High"},
{level: MemoryPressureCritical, name: "Critical"},
{level: MemoryPressureLevel(999), name: "Unknown"},
}
for _, p := range pressures {
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
}
})
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.NotZero(t, stats.Timestamp)
})
}
// TestBackgroundTaskRegistry tests background task registry edge cases
func TestBackgroundTaskRegistry(t *testing.T) {
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
registry1 := GetGlobalTaskRegistry()
registry2 := GetGlobalTaskRegistry()
assert.Equal(t, registry1, registry2, "should return same instance")
})
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-register-task"
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask(taskName, task)
assert.NoError(t, err)
// Verify task was registered
_, exists := registry.GetTask(taskName)
assert.True(t, exists, "task should be registered")
// Clean up
task.Stop()
})
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-singleton-idempotent"
callCount := 0
var mu sync.Mutex
taskFunc := func() {
mu.Lock()
callCount++
mu.Unlock()
}
// First creation should succeed
task1, err1 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err1)
assert.NotNil(t, task1)
// Second creation should also succeed (idempotent)
// Returns same task without error
task2, err2 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
assert.NotNil(t, task2)
// Clean up
if task1 != nil {
task1.Stop()
}
})
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Initially should be 0 or small number
initialCount := registry.GetTaskCount()
// Create a task
task := NewBackgroundTask(
"count-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask("count-test-task", task)
assert.NoError(t, err)
// Count should increase
newCount := registry.GetTaskCount()
assert.Equal(t, initialCount+1, newCount)
// Clean up
task.Stop()
})
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Create multiple tasks
for i := 0; i < 3; i++ {
taskName := "multi-task-" + string(rune(i+'0'))
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask(taskName, task)
}
// Verify tasks were created
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
// Stop all tasks
registry.StopAllTasks()
// Verify all tasks are removed
taskCount := registry.GetTaskCount()
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
})
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
// Create a task
task := NewBackgroundTask(
"reset-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask("reset-test-task", task)
// Reset
ResetGlobalTaskRegistry()
// Get new registry
newRegistry := GetGlobalTaskRegistry()
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
})
}
// TestBackgroundTaskLifecycle tests background task lifecycle
func TestBackgroundTaskLifecycle(t *testing.T) {
t.Run("Start begins task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
executed := false
var mu sync.Mutex
task := NewBackgroundTask(
"lifecycle-test",
50*time.Millisecond,
func() {
mu.Lock()
executed = true
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Wait for execution
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Verify it executed
mu.Lock()
wasExecuted := executed
mu.Unlock()
assert.True(t, wasExecuted, "task should have executed")
})
t.Run("Stop halts task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"stop-test",
30*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Let it run a few times
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Record count
mu.Lock()
countAfterStop := execCount
mu.Unlock()
// Wait more
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Count should not increase
mu.Lock()
finalCount := execCount
mu.Unlock()
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
})
t.Run("Multiple Start calls are safe", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"multi-start-test",
100*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Multiple starts should be safe
task.Start()
task.Start()
task.Start()
// Wait a bit
time.Sleep(GetTestDuration(50 * time.Millisecond))
// Stop task
task.Stop()
// Should have executed, but only one goroutine
mu.Lock()
count := execCount
mu.Unlock()
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
})
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
task := NewBackgroundTask(
"multi-stop-test",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
// Start and stop
task.Start()
time.Sleep(GetTestDuration(20 * time.Millisecond))
// Multiple stops should be safe
assert.NotPanics(t, func() {
task.Stop()
task.Stop()
task.Stop()
})
})
}
// TestMemoryMonitorIntegration tests memory monitor integration
func TestMemoryMonitorIntegration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory monitor integration test in short mode")
}
t.Run("monitoring updates stats", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
defer monitor.StopMonitoring()
// Start monitoring
ctx := context.Background()
monitor.StartMonitoring(ctx, 50*time.Millisecond)
// Wait for at least one check
time.Sleep(GetTestDuration(150 * time.Millisecond))
// Get pressure (should be a valid pressure level)
pressure := monitor.GetMemoryPressure()
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, pressure, "pressure should be a valid level")
// Stop monitoring
monitor.StopMonitoring()
})
t.Run("global memory monitor singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
monitor1 := GetGlobalMemoryMonitor()
monitor2 := GetGlobalMemoryMonitor()
assert.Equal(t, monitor1, monitor2, "should return same instance")
})
}
// TestMemoryStatsCollection tests memory statistics collection
func TestMemoryStatsCollection(t *testing.T) {
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.HeapSysBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.False(t, stats.Timestamp.IsZero())
})
t.Run("Stats include memory pressure", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
// Should calculate and include pressure level
assert.NotNil(t, stats.MemoryPressure)
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, stats.MemoryPressure)
})
t.Run("TriggerGC reduces memory", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Allocate some memory
_ = make([]byte, 1024*1024) // 1MB
// Get stats before GC
beforeStats := monitor.GetCurrentStats()
// Trigger GC
monitor.TriggerGC()
// Get stats after GC
afterStats := monitor.GetCurrentStats()
// After GC should have different stats
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
})
}
+241
View File
@@ -0,0 +1,241 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// =============================================================================
// UNIVERSAL CACHE BENCHMARKS
// =============================================================================
func BenchmarkCacheSet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
i++
}
})
}
func BenchmarkCacheGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
for i := 0; i < 1000; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Get(fmt.Sprintf("key%d", i%1000))
i++
}
})
}
func BenchmarkCacheSetGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key%d", i)
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
cache.Get(key)
i++
}
})
}
func BenchmarkCacheLRUEviction(b *testing.B) {
config := createTestCacheConfig()
config.MaxSize = 100
cache := NewUniversalCache(config)
defer cache.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
}
func BenchmarkCacheConcurrent(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
switch i % 3 {
case 0:
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
case 1:
cache.Get(fmt.Sprintf("key%d", i))
case 2:
cache.Delete(fmt.Sprintf("key%d", i))
}
i++
}
})
}
// =============================================================================
// CACHE MANAGER BENCHMARKS
// =============================================================================
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
// =============================================================================
// CACHE COMPATIBILITY BENCHMARKS
// =============================================================================
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
// =============================================================================
// SHARDED CACHE BENCHMARKS
// =============================================================================
func BenchmarkShardedCache(b *testing.B) {
b.Run("Set", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
})
b.Run("Get", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
for i := 0; i < 10000; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get(fmt.Sprintf("key-%d", i%10000))
}
})
b.Run("ParallelSetGet", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, i, 5*time.Minute)
cache.Get(key)
i++
}
})
})
}
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
b.Run("ShardedCache64", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
if !cache.Exists(key) {
cache.Set(key, true, 5*time.Minute)
}
i++
}
})
})
b.Run("GlobalMutexCache", func(b *testing.B) {
var mu sync.RWMutex
data := make(map[string]bool)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
mu.RLock()
_, exists := data[key]
mu.RUnlock()
if !exists {
mu.Lock()
data[key] = true
mu.Unlock()
}
i++
}
})
})
}
+3 -3
View File
@@ -155,9 +155,9 @@ type CacheStrategy interface {
// CacheEntry for backward compatibility
type CacheEntry struct {
Key string
Value interface{}
ExpiresAt time.Time
Value interface{}
Key string
}
// Cache is an alias for backward compatibility
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
// ListNode for backward compatibility
type ListNode struct {
Key string
Value interface{}
Next *ListNode
Prev *ListNode
Key string
}
// NewFixedMetadataCache creates a metadata cache with fixed configuration
-369
View File
@@ -1,369 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
// TestNewBoundedCache tests creation of bounded cache
func TestNewBoundedCache(t *testing.T) {
maxSize := 500
cache := NewBoundedCache(maxSize)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify we can use basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestDefaultUnifiedCacheConfig tests default configuration
func TestDefaultUnifiedCacheConfig(t *testing.T) {
config := DefaultUnifiedCacheConfig()
if config.Type != CacheTypeGeneral {
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
}
if config.MaxSize != 500 {
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
}
if config.MaxMemoryBytes != 64*1024*1024 {
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
}
if config.CleanupInterval != 2*time.Minute {
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
}
if config.Logger == nil {
t.Error("Expected Logger to be set")
}
}
// TestNewUnifiedCache tests unified cache creation
func TestNewUnifiedCache(t *testing.T) {
config := DefaultUnifiedCacheConfig()
cache := NewUnifiedCache(config)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
if cache.UniversalCache == nil {
t.Error("Expected UniversalCache to be set")
}
// Test basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
func TestUnifiedCache_SetMaxSize(t *testing.T) {
config := DefaultUnifiedCacheConfig()
cache := NewUnifiedCache(config)
// Test setting max size
newSize := 1000
cache.SetMaxSize(newSize)
// We can't easily verify the size was set without exposing internal fields,
// but we can ensure the method doesn't panic
}
// TestNewCacheAdapter tests cache adapter creation
func TestNewCacheAdapter(t *testing.T) {
tests := []struct {
name string
cache interface{}
expectNil bool
description string
}{
{
name: "UniversalCache",
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
expectNil: false,
description: "Should create adapter for UniversalCache",
},
{
name: "UnifiedCache",
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
expectNil: false,
description: "Should create adapter for UnifiedCache",
},
{
name: "Invalid cache type",
cache: "not-a-cache",
expectNil: true,
description: "Should return nil for invalid cache type",
},
{
name: "Nil cache",
cache: nil,
expectNil: true,
description: "Should return nil for nil cache",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
adapter := NewCacheAdapter(tt.cache)
if tt.expectNil {
if adapter != nil {
t.Errorf("Expected nil adapter, got %v", adapter)
}
} else {
if adapter == nil {
t.Error("Expected non-nil adapter")
}
// Test basic operations
adapter.Set("test", "value", time.Hour)
value, found := adapter.Get("test")
if !found {
t.Error("Expected key to be found")
}
if value != "value" {
t.Errorf("Expected 'value', got %v", value)
}
}
})
}
}
// TestNewOptimizedCache tests optimized cache creation
func TestNewOptimizedCache(t *testing.T) {
cache := NewOptimizedCache()
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestNewLRUStrategy tests LRU strategy creation
func TestNewLRUStrategy(t *testing.T) {
maxSize := 100
strategy := NewLRUStrategy(maxSize)
if strategy == nil {
t.Fatal("Expected strategy to be created, got nil")
}
lruStrategy, ok := strategy.(*LRUStrategy)
if !ok {
t.Fatal("Expected LRUStrategy type")
}
if lruStrategy.maxSize != maxSize {
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
}
if lruStrategy.order == nil {
t.Error("Expected order list to be initialized")
}
if lruStrategy.elements == nil {
t.Error("Expected elements map to be initialized")
}
}
// TestLRUStrategy_Name tests strategy name
func TestLRUStrategy_Name(t *testing.T) {
strategy := NewLRUStrategy(100)
name := strategy.Name()
if name != "LRU" {
t.Errorf("Expected 'LRU', got %s", name)
}
}
// TestLRUStrategy_ShouldEvict tests eviction logic
func TestLRUStrategy_ShouldEvict(t *testing.T) {
strategy := NewLRUStrategy(100)
// LRU strategy always returns false for ShouldEvict
result := strategy.ShouldEvict("test-item", time.Now())
if result != false {
t.Error("Expected ShouldEvict to return false")
}
}
// TestLRUStrategy_OnAccess tests access callback
func TestLRUStrategy_OnAccess(t *testing.T) {
strategy := NewLRUStrategy(100)
// OnAccess should not panic
strategy.OnAccess("test-key", "test-value")
}
// TestLRUStrategy_OnRemove tests removal callback
func TestLRUStrategy_OnRemove(t *testing.T) {
strategy := NewLRUStrategy(100)
// OnRemove should not panic
strategy.OnRemove("test-key")
}
// TestLRUStrategy_EstimateSize tests size estimation
func TestLRUStrategy_EstimateSize(t *testing.T) {
strategy := NewLRUStrategy(100)
size := strategy.EstimateSize("test-item")
if size != 64 {
t.Errorf("Expected size 64, got %d", size)
}
}
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
strategy := NewLRUStrategy(100)
key, found := strategy.GetEvictionCandidate()
if found {
t.Error("Expected no eviction candidate to be found")
}
if key != "" {
t.Errorf("Expected empty key, got %s", key)
}
}
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
func TestNewOptimizedCacheWithConfig(t *testing.T) {
config := UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 1000,
MaxMemoryBytes: 128 * 1024 * 1024,
EnableMetrics: true,
Logger: GetSingletonNoOpLogger(),
}
cache := NewOptimizedCacheWithConfig(config)
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with basic operations
cache.Set("test-key", "test-value", time.Hour)
value, found := cache.Get("test-key")
if !found {
t.Error("Expected key to be found in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
}
// TestNewFixedMetadataCache tests fixed metadata cache creation
func TestNewFixedMetadataCache(t *testing.T) {
cache := NewFixedMetadataCache()
if cache == nil {
t.Fatal("Expected cache to be created, got nil")
}
// Verify it works with proper metadata operations
metadata := &ProviderMetadata{
Issuer: "https://example.com",
AuthURL: "https://example.com/auth",
TokenURL: "https://example.com/token",
JWKSURL: "https://example.com/jwks",
}
err := cache.Set("test-provider", metadata, time.Hour)
if err != nil {
t.Errorf("Unexpected error setting metadata: %v", err)
}
// Test that the cache was created (basic verification)
// Note: We can't easily test Get without more complex setup
}
// TestNewDoublyLinkedList tests doubly linked list creation
func TestNewDoublyLinkedList(t *testing.T) {
list := NewDoublyLinkedList()
if list == nil {
t.Fatal("Expected list to be created, got nil")
}
// Test it's a proper list structure
if list.Len() != 0 {
t.Error("Expected empty list initially")
}
}
// TestDoublyLinkedList_PopFront tests front element removal
func TestDoublyLinkedList_PopFront(t *testing.T) {
list := NewDoublyLinkedList()
// Test popping from empty list
element := list.PopFront()
if element != nil {
t.Error("Expected nil when popping from empty list")
}
// Add an element and test popping
added := list.PushBack("test-value")
if added == nil {
t.Fatal("Expected element to be added")
}
popped := list.PopFront()
if popped == nil {
t.Error("Expected element to be popped")
}
if list.Len() != 0 {
t.Error("Expected list to be empty after popping")
}
}
// Benchmark tests for performance
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
File diff suppressed because it is too large Load Diff
+57 -7
View File
@@ -21,10 +21,37 @@ var (
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
cacheManagerInitOnce.Do(func() {
var redisConfig *RedisConfig
var logger *Logger
if config != nil {
logger = NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
redisConfig = config.Redis
}
globalCacheManagerInstance = &CacheManager{
manager: GetUniversalCacheManager(nil),
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
return globalCacheManagerInstance
@@ -34,7 +61,7 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
}
// GetSharedTokenCache returns the shared token cache
@@ -61,6 +88,22 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
return &JWKCache{cache: cm.manager.GetJWKCache()}
}
// GetSharedIntrospectionCache returns the shared token introspection cache
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
}
// GetSharedTokenTypeCache returns the shared token type cache
// for caching token type detection results to improve performance
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
@@ -78,12 +121,13 @@ func CleanupGlobalCacheManager() error {
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
type CacheInterfaceWrapper struct {
cache *UniversalCache
cache *UniversalCache
managed bool // If true, cache is managed globally and Close() is a no-op
}
// Set stores a value
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
c.cache.Set(key, value, ttl)
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
}
// Get retrieves a value
@@ -106,11 +150,17 @@ func (c *CacheInterfaceWrapper) Cleanup() {
c.cache.Cleanup()
}
// Close shuts down the cache
// Close shuts down the cache if it's not managed globally.
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
// when multiple plugin instances are closed during Traefik configuration reloads.
func (c *CacheInterfaceWrapper) Close() {
// Close the underlying cache to stop goroutines
if c.managed {
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
return
}
// Standalone cache - close it properly to stop cleanup goroutines
if c.cache != nil {
c.cache.Close()
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
}
}
-314
View File
@@ -1,314 +0,0 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// Helper function to ensure we have a working cache manager for tests
func getTestCacheManager(t *testing.T) *CacheManager {
cm := GetGlobalCacheManager(&sync.WaitGroup{})
if cm == nil {
t.Fatal("Failed to get cache manager")
}
if cm.manager == nil {
t.Fatal("Cache manager has nil internal manager")
}
return cm
}
// TestCacheManager_Close tests cache manager close functionality
func TestCacheManager_Close(t *testing.T) {
// Get a fresh cache manager
wg := &sync.WaitGroup{}
cm := GetGlobalCacheManager(wg)
if cm == nil {
t.Fatal("Expected cache manager to be created")
}
// Test closing the cache manager
err := cm.Close()
if err != nil {
t.Errorf("Unexpected error closing cache manager: %v", err)
}
}
// TestCleanupGlobalCacheManager tests global cleanup
func TestCleanupGlobalCacheManager(t *testing.T) {
// Test cleanup when no instance exists (should not error)
originalInstance := globalCacheManagerInstance
globalCacheManagerInstance = nil
err := CleanupGlobalCacheManager()
if err != nil {
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
}
// Restore original instance
globalCacheManagerInstance = originalInstance
}
// TestCacheInterfaceWrapper_Delete tests delete functionality
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add an item
cache.Set("test-key", "test-value", time.Hour)
// Verify it exists
value, found := cache.Get("test-key")
if !found {
t.Fatal("Expected key to be found after setting")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
// Delete it
cache.Delete("test-key")
// Verify it's gone
_, found = cache.Get("test-key")
if found {
t.Error("Expected key to be deleted")
}
}
// TestCacheInterfaceWrapper_Size tests size functionality
func TestCacheInterfaceWrapper_Size(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Clear cache first
cache.Clear()
// Check initial size
initialSize := cache.Size()
if initialSize != 0 {
t.Errorf("Expected initial size 0, got %d", initialSize)
}
// Add some items
cache.Set("key1", "value1", time.Hour)
cache.Set("key2", "value2", time.Hour)
// Check size increased
newSize := cache.Size()
if newSize != 2 {
t.Errorf("Expected size 2, got %d", newSize)
}
}
// TestCacheInterfaceWrapper_Clear tests clear functionality
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add some items
cache.Set("key1", "value1", time.Hour)
cache.Set("key2", "value2", time.Hour)
// Verify items exist
size := cache.Size()
if size != 2 {
t.Errorf("Expected 2 items before clear, got %d", size)
}
// Clear all
cache.Clear()
// Verify cache is empty
size = cache.Size()
if size != 0 {
t.Errorf("Expected 0 items after clear, got %d", size)
}
// Verify specific items are gone
_, found := cache.Get("key1")
if found {
t.Error("Expected key1 to be cleared")
}
_, found = cache.Get("key2")
if found {
t.Error("Expected key2 to be cleared")
}
}
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
func TestCacheInterfaceWrapper_Close(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Test close - should not panic
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
wrapper.Close() // Should not panic
// Test close with nil cache
nilWrapper := &CacheInterfaceWrapper{cache: nil}
nilWrapper.Close() // Should not panic
}
// TestCacheInterfaceWrapper_GetStats tests stats functionality
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Get stats
stats := wrapper.GetStats()
if stats == nil {
t.Error("Expected non-nil stats")
}
// Stats should be accessible (len() never returns negative values)
// Just verify it's accessible by checking it's not nil (already done above)
}
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Add an item that will expire quickly
cache.Set("expire-key", "expire-value", time.Millisecond)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
// Trigger cleanup
cache.Cleanup()
// Item should be cleaned up
_, found := cache.Get("expire-key")
if found {
t.Error("Expected expired key to be cleaned up")
}
}
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Test setting max size (should not panic)
cache.SetMaxSize(1000)
// We can't easily verify the size was set without exposing internals,
// but we can ensure the method doesn't panic
}
// TestGetSharedCaches tests getting shared cache instances
func TestGetSharedCaches(t *testing.T) {
cm := getTestCacheManager(t)
// Test getting shared token blacklist
blacklist := cm.GetSharedTokenBlacklist()
if blacklist == nil {
t.Error("Expected non-nil token blacklist")
}
// Test getting shared token cache
tokenCache := cm.GetSharedTokenCache()
if tokenCache == nil {
t.Error("Expected non-nil token cache")
}
// Test getting shared metadata cache
metadataCache := cm.GetSharedMetadataCache()
if metadataCache == nil {
t.Error("Expected non-nil metadata cache")
}
// Test getting shared JWK cache
jwkCache := cm.GetSharedJWKCache()
if jwkCache == nil {
t.Error("Expected non-nil JWK cache")
}
}
// TestConcurrentCacheAccess tests thread safety
func TestConcurrentCacheAccess(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
var wg sync.WaitGroup
goroutines := 10
iterations := 10
// Concurrent operations
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("key-%d-%d", id, j)
value := fmt.Sprintf("value-%d-%d", id, j)
cache.Set(key, value, time.Hour)
retrieved, found := cache.Get(key)
if found && retrieved != value {
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
}
cache.Delete(key)
}
}(i)
}
wg.Wait()
}
// Benchmark tests for performance
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
// Pre-populate cache
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
+1854
View File
File diff suppressed because it is too large Load Diff
-319
View File
@@ -1,319 +0,0 @@
// Package circuit_breaker provides circuit breaker implementation for resilience
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// CircuitBreakerState represents the current state of a circuit breaker.
// The circuit breaker pattern prevents cascading failures by monitoring
// error rates and temporarily blocking requests to failing services.
type CircuitBreakerState int
// Circuit breaker states following the standard pattern:
// Closed: Normal operation, requests flow through
// Open: Circuit is tripped, requests are blocked
// HalfOpen: Testing state, limited requests allowed to test recovery
const (
// CircuitBreakerClosed allows all requests through (normal operation)
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen blocks all requests (service is failing)
CircuitBreakerOpen
// CircuitBreakerHalfOpen allows limited requests to test service recovery
CircuitBreakerHalfOpen
)
// String returns a string representation of the circuit breaker state
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Logger interface for dependency injection
type Logger interface {
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// BaseRecoveryMechanism interface for common functionality
type BaseRecoveryMechanism interface {
RecordRequest()
RecordSuccess()
RecordFailure()
GetBaseMetrics() map[string]interface{}
LogInfo(format string, args ...interface{})
LogError(format string, args ...interface{})
LogDebug(format string, args ...interface{})
}
// CircuitBreaker implements the circuit breaker pattern for external service calls.
// It monitors failure rates and automatically opens the circuit when failures
// exceed the threshold, preventing further requests until the service recovers.
type CircuitBreaker struct {
// baseRecovery provides common functionality
baseRecovery BaseRecoveryMechanism
// maxFailures is the threshold for opening the circuit
maxFailures int
// timeout is how long to wait before allowing requests in half-open state
timeout time.Duration
// resetTimeout is how long to wait before transitioning from open to half-open
resetTimeout time.Duration
// state tracks the current circuit breaker state
state CircuitBreakerState
// failures counts consecutive failures
failures int64
// lastFailureTime records when the last failure occurred
lastFailureTime time.Time
// mutex protects shared state
mutex sync.RWMutex
// logger for debugging and monitoring
logger Logger
}
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
// These settings control when the circuit opens and how it recovers.
type CircuitBreakerConfig struct {
// MaxFailures is the number of failures before opening the circuit
MaxFailures int `json:"max_failures"`
// Timeout is how long to wait before trying to recover (open -> half-open)
Timeout time.Duration `json:"timeout"`
// ResetTimeout is how long to wait before fully closing the circuit
ResetTimeout time.Duration `json:"reset_timeout"`
}
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
// Configured for typical web service scenarios with moderate tolerance for failures.
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 60 * time.Second,
ResetTimeout: 30 * time.Second,
}
}
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
// The circuit breaker starts in the closed state, allowing all requests through.
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
return &CircuitBreaker{
baseRecovery: baseRecovery,
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
logger: logger,
}
}
// ExecuteWithContext executes a function through the circuit breaker with context.
// It checks if requests are allowed, executes the function, and updates the circuit state
// based on the result. Implements the ErrorRecoveryMechanism interface.
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
if cb.baseRecovery != nil {
cb.baseRecovery.RecordRequest()
}
if !cb.allowRequest() {
return fmt.Errorf("circuit breaker is open")
}
err := fn()
if err != nil {
cb.recordFailure()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordFailure()
}
return err
}
cb.recordSuccess()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordSuccess()
}
return nil
}
// Execute executes a function through the circuit breaker without context.
// This is provided for backward compatibility with existing code.
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest determines whether to allow a request based on the circuit state.
// Handles state transitions from open to half-open based on timeout.
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
defer cb.mutex.Unlock()
now := time.Now()
switch cb.state {
case CircuitBreakerClosed:
return true
case CircuitBreakerOpen:
if now.Sub(cb.lastFailureTime) > cb.timeout {
cb.state = CircuitBreakerHalfOpen
if cb.logger != nil {
cb.logger.Infof("Circuit breaker transitioning to half-open state")
}
return true
}
return false
case CircuitBreakerHalfOpen:
return true
default:
return false
}
}
// recordFailure records a failure and potentially opens the circuit.
// Updates failure count and triggers state transitions when thresholds are exceeded.
func (cb *CircuitBreaker) recordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.failures++
cb.lastFailureTime = time.Now()
switch cb.state {
case CircuitBreakerClosed:
if cb.failures >= int64(cb.maxFailures) {
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
}
}
case CircuitBreakerHalfOpen:
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
}
}
}
// recordSuccess records a successful request and potentially closes the circuit.
// Resets failure count and transitions from half-open to closed state on success.
func (cb *CircuitBreaker) recordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
switch cb.state {
case CircuitBreakerHalfOpen:
cb.failures = 0
cb.state = CircuitBreakerClosed
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
}
case CircuitBreakerClosed:
cb.failures = 0
}
}
// GetState returns the current state of the circuit breaker.
// Thread-safe method for monitoring circuit breaker status.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// Reset resets the circuit breaker to its initial closed state.
// Clears failure count and state, effectively recovering from any open state.
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = CircuitBreakerClosed
atomic.StoreInt64(&cb.failures, 0)
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
}
}
// IsAvailable returns whether the circuit breaker is currently allowing requests.
// This provides a quick way to check if the service is available.
func (cb *CircuitBreaker) IsAvailable() bool {
return cb.allowRequest()
}
// GetMetrics returns comprehensive metrics about the circuit breaker.
// Includes state information, failure counts, configuration, and base metrics.
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
cb.mutex.RLock()
state := cb.state
failures := cb.failures
lastFailureTime := cb.lastFailureTime
cb.mutex.RUnlock()
var metrics map[string]interface{}
if cb.baseRecovery != nil {
metrics = cb.baseRecovery.GetBaseMetrics()
} else {
metrics = make(map[string]interface{})
}
metrics["state"] = state.String()
metrics["current_failures"] = failures
metrics["max_failures"] = cb.maxFailures
metrics["timeout"] = cb.timeout.String()
metrics["reset_timeout"] = cb.resetTimeout.String()
if !lastFailureTime.IsZero() {
metrics["last_failure_time"] = lastFailureTime
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
}
return metrics
}
// GetFailureCount returns the current failure count
func (cb *CircuitBreaker) GetFailureCount() int64 {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.failures
}
// GetLastFailureTime returns the time of the last failure
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.lastFailureTime
}
// IsOpen returns true if the circuit breaker is in open state
func (cb *CircuitBreaker) IsOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerOpen
}
// IsClosed returns true if the circuit breaker is in closed state
func (cb *CircuitBreaker) IsClosed() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerClosed
}
// IsHalfOpen returns true if the circuit breaker is in half-open state
func (cb *CircuitBreaker) IsHalfOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerHalfOpen
}
-981
View File
@@ -1,981 +0,0 @@
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// Mock implementations for testing
type mockLogger struct {
infoLogs []string
errorLogs []string
debugLogs []string
mu sync.RWMutex
}
func (m *mockLogger) Infof(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Errorf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Debugf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
//lint:ignore U1000 May be needed for future error log verification tests
func (m *mockLogger) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
//lint:ignore U1000 May be needed for future test isolation
func (m *mockLogger) reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = nil
m.errorLogs = nil
m.debugLogs = nil
}
type mockBaseRecoveryMechanism struct {
requestCount int64
successCount int64
failureCount int64
infoLogs []string
errorLogs []string
debugLogs []string
baseMetrics map[string]interface{}
mu sync.RWMutex
}
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
return &mockBaseRecoveryMechanism{
baseMetrics: make(map[string]interface{}),
}
}
func (m *mockBaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&m.requestCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&m.successCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&m.failureCount, 1)
}
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]interface{})
for k, v := range m.baseMetrics {
result[k] = v
}
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
result["total_successes"] = atomic.LoadInt64(&m.successCount)
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
return result
}
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
return atomic.LoadInt64(&m.requestCount)
}
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
return atomic.LoadInt64(&m.successCount)
}
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
return atomic.LoadInt64(&m.failureCount)
}
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
func TestCircuitBreakerState_String(t *testing.T) {
tests := []struct {
state CircuitBreakerState
expected string
}{
{CircuitBreakerClosed, "closed"},
{CircuitBreakerOpen, "open"},
{CircuitBreakerHalfOpen, "half-open"},
{CircuitBreakerState(999), "unknown"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := tt.state.String()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
}
}
func TestNewCircuitBreaker(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
if cb == nil {
t.Fatal("NewCircuitBreaker returned nil")
}
if cb.maxFailures != 3 {
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
}
if cb.timeout != 30*time.Second {
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
}
if cb.resetTimeout != 15*time.Second {
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
}
if cb.state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
}
if cb.logger != logger {
t.Error("Expected logger to be set")
}
if cb.baseRecovery != baseRecovery {
t.Error("Expected baseRecovery to be set")
}
}
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getSuccessCount() != 1 {
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
}
}
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getFailureCount() != 1 {
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
}
}
func TestCircuitBreaker_Execute(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.Execute(testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
}
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
// First failure
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on first failure, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
}
// Second failure - should open circuit
err = cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on second failure, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
}
// Third attempt - should be blocked
callCount := 0
blockedFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(ctx, blockedFunc)
if err == nil {
t.Error("Expected error when circuit is open")
}
if callCount != 0 {
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
}
}
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond, // Very short for testing
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next request should transition to half-open
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error in half-open state, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
}
}
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout to allow half-open transition
time.Sleep(15 * time.Millisecond)
// First call should transition to half-open, but we'll force it by checking allowRequest
if !cb.allowRequest() {
t.Error("Expected allowRequest to return true after timeout")
}
if cb.GetState() != CircuitBreakerHalfOpen {
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
}
// Failure in half-open should return to open
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
}
}
func TestCircuitBreaker_Reset(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Reset circuit
cb.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
}
if cb.GetFailureCount() != 0 {
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
}
// Should allow requests again
callCount := 0
err := cb.ExecuteWithContext(context.Background(), func() error {
callCount++
return nil
})
if err != nil {
t.Errorf("Expected no error after reset, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
}
}
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially available
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should not be available when open
if cb.IsAvailable() {
t.Error("Expected circuit breaker to be unavailable when open")
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Should be available again after timeout (half-open)
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available after timeout")
}
}
func TestCircuitBreaker_StateCheckers(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially closed
if !cb.IsClosed() {
t.Error("Expected circuit breaker to be closed initially")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open initially")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should be open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when open")
}
if !cb.IsOpen() {
t.Error("Expected circuit breaker to be open")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open when open")
}
// Wait for timeout and trigger half-open
time.Sleep(15 * time.Millisecond)
cb.allowRequest() // This will transition to half-open
// Should be half-open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when half-open")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open when half-open")
}
if !cb.IsHalfOpen() {
t.Error("Expected circuit breaker to be half-open")
}
}
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Record some activity
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
metrics := cb.GetMetrics()
// Check circuit breaker specific metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["current_failures"] != int64(1) {
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
if metrics["timeout"] != "30s" {
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
}
if metrics["reset_timeout"] != "15s" {
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
}
// Check base metrics are included
if metrics["total_requests"] != int64(1) {
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
}
if metrics["custom_metric"] != "custom_value" {
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
}
// Check failure time metrics
if _, exists := metrics["last_failure_time"]; !exists {
t.Error("Expected last_failure_time to exist")
}
if _, exists := metrics["time_since_last_failure"]; !exists {
t.Error("Expected time_since_last_failure to exist")
}
}
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
metrics := cb.GetMetrics()
// Should still have circuit breaker metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
// Should not have base metrics
if _, exists := metrics["total_requests"]; exists {
t.Error("Expected total_requests not to exist without base recovery")
}
}
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially should be zero
if !cb.GetLastFailureTime().IsZero() {
t.Error("Expected last failure time to be zero initially")
}
// Record a failure
before := time.Now()
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
after := time.Now()
lastFailure := cb.GetLastFailureTime()
if lastFailure.IsZero() {
t.Error("Expected last failure time to be set after failure")
}
if lastFailure.Before(before) || lastFailure.After(after) {
t.Errorf("Expected last failure time to be between %v and %v, got %v",
before, after, lastFailure)
}
}
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
// Should work fine without base recovery
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
}
}
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 10, // Higher threshold for concurrent test
Timeout: 100 * time.Millisecond,
ResetTimeout: 50 * time.Millisecond,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
const numGoroutines = 10
const numOperations = 50
var wg sync.WaitGroup
successCount := int64(0)
errorCount := int64(0)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
err := cb.ExecuteWithContext(context.Background(), func() error {
// Simulate some failures
if j%10 == 9 { // Every 10th operation fails
return fmt.Errorf("simulated error")
}
return nil
})
if err != nil {
atomic.AddInt64(&errorCount, 1)
} else {
atomic.AddInt64(&successCount, 1)
}
// Intermittently check state and metrics
if j%5 == 0 {
cb.GetState()
cb.GetMetrics()
cb.IsAvailable()
}
}
}(i)
}
wg.Wait()
// Verify we got both successes and errors
finalSuccessCount := atomic.LoadInt64(&successCount)
finalErrorCount := atomic.LoadInt64(&errorCount)
if finalSuccessCount == 0 {
t.Error("Expected some successful operations")
}
if finalErrorCount == 0 {
t.Error("Expected some failed operations")
}
totalOperations := finalSuccessCount + finalErrorCount
expectedMax := int64(numGoroutines * numOperations)
if totalOperations > expectedMax {
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
}
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
finalSuccessCount, finalErrorCount, cb.GetState())
}
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Check that error was logged when circuit opened
errorLogs := baseRecovery.getErrorLogs()
if len(errorLogs) == 0 {
t.Error("Expected error log when circuit breaker opened")
} else {
if !contains(errorLogs, "Circuit breaker opened after") {
t.Errorf("Expected circuit opening log, got %v", errorLogs)
}
}
// Wait and trigger half-open
time.Sleep(15 * time.Millisecond)
// Successful request should close circuit and log
cb.ExecuteWithContext(context.Background(), func() error {
return nil
})
// Check that success was logged when circuit closed
infoLogs := baseRecovery.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log when circuit breaker closed")
} else {
if !contains(infoLogs, "Circuit breaker closed after successful request") {
t.Errorf("Expected circuit closing log, got %v", infoLogs)
}
}
// Reset should also be logged
cb.Reset()
infoLogs = baseRecovery.getInfoLogs()
if !contains(infoLogs, "Circuit breaker has been reset") {
t.Errorf("Expected reset log, got %v", infoLogs)
}
}
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Wait for timeout and check half-open transition logging
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next allowRequest call should log transition to half-open
cb.allowRequest()
infoLogs := logger.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log for half-open transition")
} else {
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
t.Errorf("Expected half-open transition log, got %v", infoLogs)
}
}
}
// Helper function to check if a slice contains a string with substring
func contains(slice []string, substr string) bool {
for _, s := range slice {
if len(s) >= len(substr) && s[:len(substr)] == substr {
return true
}
}
return false
}
// Benchmark tests
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testFunc := func() error {
return nil
}
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.ExecuteWithContext(ctx, testFunc)
}
})
}
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
config := CircuitBreakerConfig{
MaxFailures: 1000, // High threshold to avoid opening during benchmark
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.ExecuteWithContext(ctx, testFunc)
}
}
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.GetState()
}
})
}
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Add some activity
for i := 0; i < 100; i++ {
if i%2 == 0 {
cb.ExecuteWithContext(context.Background(), func() error { return nil })
} else {
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.GetMetrics()
}
}
File diff suppressed because it is too large Load Diff
-428
View File
@@ -1,428 +0,0 @@
// Package config provides configuration management for the OIDC middleware
package config
import (
"fmt"
"net/http"
"strconv"
"strings"
)
const (
minEncryptionKeyLength = 16
ConstSessionTimeout = 86400
)
//lint:ignore U1000 May be referenced for default exclusion patterns
var defaultExcludedURLs = map[string]struct{}{
"/favicon.ico": {},
"/robots.txt": {},
"/health": {},
"/.well-known/": {},
"/metrics": {},
"/ping": {},
"/api/": {},
"/static/": {},
"/assets/": {},
"/js/": {},
"/css/": {},
"/images/": {},
"/fonts/": {},
}
// Settings manages configuration and initialization for the OIDC middleware
type Settings struct {
logger Logger
}
// Logger interface for dependency injection
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// Config represents the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerUrl"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackUrl"`
LogoutURL string `json:"logoutUrl"`
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHttps"`
LogLevel string `json:"logLevel"`
Scopes []string `json:"scopes"`
OverrideScopes bool `json:"overrideScopes"`
AllowedUsers []string `json:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedUrls"`
EnablePKCE bool `json:"enablePkce"`
RateLimit int `json:"rateLimit"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
Headers []HeaderConfig `json:"headers"`
HTTPClient *http.Client `json:"-"`
CookieDomain string `json:"cookieDomain"`
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
}
// HeaderConfig represents header template configuration
type HeaderConfig struct {
Name string `json:"name"`
Value string `json:"value"`
}
// SecurityHeadersConfig configures security headers for the plugin
type SecurityHeadersConfig struct {
// Enable security headers (default: true)
Enabled bool `json:"enabled"`
// Security profile: "default", "strict", "development", "api", or "custom"
Profile string `json:"profile"`
// Content Security Policy
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
// HSTS settings
StrictTransportSecurity bool `json:"strictTransportSecurity"`
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
FrameOptions string `json:"frameOptions,omitempty"`
// Content type options (default: "nosniff")
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
// XSS protection (default: "1; mode=block")
XSSProtection string `json:"xssProtection,omitempty"`
// Referrer policy
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
// Permissions policy
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
// Cross-origin settings
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
// CORS settings
CORSEnabled bool `json:"corsEnabled"`
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
CORSAllowCredentials bool `json:"corsAllowCredentials"`
CORSMaxAge int `json:"corsMaxAge"` // seconds
// Custom headers (in addition to standard security headers)
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
// Security features
DisableServerHeader bool `json:"disableServerHeader"`
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
}
// NewSettings creates a new Settings instance
func NewSettings(logger Logger) *Settings {
return &Settings{
logger: logger,
}
}
// CreateConfig creates a default configuration
func CreateConfig() *Config {
return &Config{
LogLevel: "INFO",
ForceHTTPS: true,
EnablePKCE: true,
RateLimit: 10,
RefreshGracePeriodSeconds: 60,
Scopes: []string{"openid", "profile", "email"},
Headers: []HeaderConfig{},
SecurityHeaders: createDefaultSecurityConfig(),
}
}
// createDefaultSecurityConfig creates a default security headers configuration
func createDefaultSecurityConfig() *SecurityHeadersConfig {
return &SecurityHeadersConfig{
Enabled: true,
Profile: "default",
// Default security headers
StrictTransportSecurity: true,
StrictTransportSecurityMaxAge: 31536000, // 1 year
StrictTransportSecuritySubdomains: true,
StrictTransportSecurityPreload: true,
FrameOptions: "DENY",
ContentTypeOptions: "nosniff",
XSSProtection: "1; mode=block",
ReferrerPolicy: "strict-origin-when-cross-origin",
// CORS disabled by default
CORSEnabled: false,
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
CORSAllowCredentials: false,
CORSMaxAge: 86400, // 24 hours
// Security features
DisableServerHeader: true,
DisablePoweredByHeader: true,
}
}
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
if c == nil || !c.Enabled {
return nil
}
// Create the internal security config structure
config := map[string]interface{}{
"DevelopmentMode": false,
}
// Apply profile-based defaults
switch strings.ToLower(c.Profile) {
case "strict":
applyStrictProfile(config)
case "development":
applyDevelopmentProfile(config)
case "api":
applyAPIProfile(config)
case "custom":
// No defaults, use only what's explicitly configured
default: // "default"
applyDefaultProfile(config)
}
// Override with explicit configuration
if c.ContentSecurityPolicy != "" {
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
}
// HSTS configuration
if c.StrictTransportSecurity {
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
}
// Frame options
if c.FrameOptions != "" {
config["FrameOptions"] = c.FrameOptions
}
// Content type and XSS protection
if c.ContentTypeOptions != "" {
config["ContentTypeOptions"] = c.ContentTypeOptions
}
if c.XSSProtection != "" {
config["XSSProtection"] = c.XSSProtection
}
// Referrer and permissions policies
if c.ReferrerPolicy != "" {
config["ReferrerPolicy"] = c.ReferrerPolicy
}
if c.PermissionsPolicy != "" {
config["PermissionsPolicy"] = c.PermissionsPolicy
}
// Cross-origin policies
if c.CrossOriginEmbedderPolicy != "" {
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
}
if c.CrossOriginOpenerPolicy != "" {
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
}
if c.CrossOriginResourcePolicy != "" {
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
}
// CORS configuration
config["CORSEnabled"] = c.CORSEnabled
if len(c.CORSAllowedOrigins) > 0 {
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
}
if len(c.CORSAllowedMethods) > 0 {
config["CORSAllowedMethods"] = c.CORSAllowedMethods
}
if len(c.CORSAllowedHeaders) > 0 {
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
}
config["CORSAllowCredentials"] = c.CORSAllowCredentials
if c.CORSMaxAge > 0 {
config["CORSMaxAge"] = c.CORSMaxAge
}
// Custom headers
if len(c.CustomHeaders) > 0 {
config["CustomHeaders"] = c.CustomHeaders
}
// Security features
config["DisableServerHeader"] = c.DisableServerHeader
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
return config
}
// applyDefaultProfile applies default security settings
func applyDefaultProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
config["CrossOriginEmbedderPolicy"] = "require-corp"
config["CrossOriginOpenerPolicy"] = "same-origin"
config["CrossOriginResourcePolicy"] = "same-origin"
}
// applyStrictProfile applies strict security settings
func applyStrictProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
config["CrossOriginEmbedderPolicy"] = "require-corp"
config["CrossOriginOpenerPolicy"] = "same-origin"
config["CrossOriginResourcePolicy"] = "same-site"
}
// applyDevelopmentProfile applies development-friendly settings
func applyDevelopmentProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
config["FrameOptions"] = "SAMEORIGIN"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["CrossOriginOpenerPolicy"] = "unsafe-none"
config["CrossOriginResourcePolicy"] = "cross-origin"
config["DevelopmentMode"] = true
}
// applyAPIProfile applies API-friendly settings
func applyAPIProfile(config map[string]interface{}) {
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
config["FrameOptions"] = "DENY"
config["ContentTypeOptions"] = "nosniff"
config["XSSProtection"] = "1; mode=block"
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
config["CrossOriginResourcePolicy"] = "cross-origin"
}
// GetSecurityHeadersApplier returns a function that applies security headers
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
return nil
}
// This would need to import the internal security package
// For now, return a basic implementation
return func(rw http.ResponseWriter, req *http.Request) {
headers := rw.Header()
// Apply basic security headers based on configuration
if c.SecurityHeaders.FrameOptions != "" {
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
}
if c.SecurityHeaders.ContentTypeOptions != "" {
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
}
if c.SecurityHeaders.XSSProtection != "" {
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
}
if c.SecurityHeaders.ReferrerPolicy != "" {
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
}
if c.SecurityHeaders.ContentSecurityPolicy != "" {
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
}
// HSTS for HTTPS
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
hstsValue += "; includeSubDomains"
}
if c.SecurityHeaders.StrictTransportSecurityPreload {
hstsValue += "; preload"
}
headers.Set("Strict-Transport-Security", hstsValue)
}
// CORS headers
if c.SecurityHeaders.CORSEnabled {
origin := req.Header.Get("Origin")
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
headers.Set("Access-Control-Allow-Origin", origin)
}
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
}
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
}
if c.SecurityHeaders.CORSAllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if c.SecurityHeaders.CORSMaxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
}
}
// Custom headers
for name, value := range c.SecurityHeaders.CustomHeaders {
headers.Set(name, value)
}
// Remove server headers
if c.SecurityHeaders.DisableServerHeader {
headers.Del("Server")
}
if c.SecurityHeaders.DisablePoweredByHeader {
headers.Del("X-Powered-By")
}
}
}
// isOriginAllowed checks if an origin is in the allowed list
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
return true
}
// Simple wildcard matching for subdomains
if strings.Contains(allowed, "*") {
if strings.HasPrefix(allowed, "https://*.") {
domain := strings.TrimPrefix(allowed, "https://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
return true
}
}
if strings.HasPrefix(allowed, "http://*.") {
domain := strings.TrimPrefix(allowed, "http://*.")
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
return true
}
}
}
}
return false
}
+116
View File
@@ -0,0 +1,116 @@
package traefikoidc
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return result, nil
}
// MarshalJSON for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalJSON() ([]byte, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return json.Marshal(result)
}
// MarshalYAML for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalYAML() (interface{}, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return result, nil
}
File diff suppressed because it is too large Load Diff
+8 -8
View File
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that CSRF tokens persist through the authentication flow
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
// Create a session manager
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Create initial request
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that marking session as dirty forces save
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test Azure-specific session handling
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate Azure callback scenario
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test session continuity through auth flow
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Step 1: Initial request
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test large token handling doesn't affect CSRF
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// We can't fully initialize TraefikOidc without network access,
// but we can test the session management directly
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
require.NoError(t, err)
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// TestRegressionLoginLoop specifically tests the fix for issue #53
func TestRegressionLoginLoop(t *testing.T) {
// This test verifies that the specific changes made to fix the login loop work correctly
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate the exact flow that was causing the login loop
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
func TestCSRFValidationTiming(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
+364
View File
@@ -0,0 +1,364 @@
//go:build !yaegi
package traefikoidc
import (
"testing"
)
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Explicitly set defaults to test backward compatibility
ts.tOidc.roleClaimName = "roles"
ts.tOidc.groupClaimName = "groups"
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
claims := map[string]interface{}{
"groups": []interface{}{"admin", "users"},
"roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names for Auth0
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with Auth0-style namespaced claims
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{"admin", "users"},
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom simple claim names
ts.tOidc.roleClaimName = "user_roles"
ts.tOidc.groupClaimName = "user_groups"
// Create token with custom claim names
claims := map[string]interface{}{
"user_groups": []interface{}{"engineering", "product"},
"user_roles": []interface{}{"developer", "manager"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
t.Errorf("Expected groups [engineering product], got %v", groups)
}
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
t.Errorf("Expected roles [developer manager], got %v", roles)
}
}
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
func TestCustomClaimNames_MissingClaims(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
ts.tOidc.groupClaimName = "custom_groups"
// Create token WITHOUT the custom claims
claims := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should return empty slices, not error
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with malformed role claim (not an array)
claims := map[string]interface{}{
"custom_roles": "this-should-be-an-array",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed role claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_roles claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with malformed group claim (not an array)
claims := map[string]interface{}{
"custom_groups": 12345, // Not an array
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed group claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_groups claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only role claim name (group uses default)
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "groups" // default
// Create token with mixed claim names
claims := map[string]interface{}{
"groups": []interface{}{"admin"},
"https://myapp.com/roles": []interface{}{"editor"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin"}) {
t.Errorf("Expected groups [admin], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor"}) {
t.Errorf("Expected roles [editor], got %v", roles)
}
}
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only group claim name (role uses default)
ts.tOidc.roleClaimName = "roles" // default
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with mixed claim names
claims := map[string]interface{}{
"roles": []interface{}{"viewer"},
"https://myapp.com/groups": []interface{}{"users"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"users"}) {
t.Errorf("Expected groups [users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"viewer"}) {
t.Errorf("Expected roles [viewer], got %v", roles)
}
}
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with empty arrays
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{},
"https://myapp.com/roles": []interface{}{},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_roles": []interface{}{"role1", 12345, "role2", true},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
t.Errorf("Expected roles [role1 role2], got %v", roles)
}
}
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
t.Errorf("Expected groups [group1 group2], got %v", groups)
}
}
+424
View File
@@ -0,0 +1,424 @@
# Auth0 Audience Validation Guide
## Overview
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
## Table of Contents
1. [Understanding Audiences](#understanding-audiences)
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
3. [Configuration Options](#configuration-options)
4. [Security Recommendations](#security-recommendations)
5. [Troubleshooting](#troubleshooting)
---
## Understanding Audiences
### What is an Audience?
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
### Why Does This Matter?
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
---
## The Three Auth0 Scenarios
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
**Configuration:**
```yaml
audience: "https://my-api.example.com" # Your API identifier from Auth0
```
**What Happens:**
1. Authorization request includes `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
3. Middleware validates:
- ID tokens against `client_id`
- Access tokens against custom audience
**Result:** ✅ Fully secure, OIDC compliant
---
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
**Configuration:**
```yaml
# audience not specified (defaults to client_id)
```
**What Happens:**
1. Authorization request WITHOUT `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
3. Access token validation fails (audience mismatch)
4. Middleware falls back to ID token validation
**Security Warning:**
```
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
⚠️ This could allow tokens intended for different APIs to grant access
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
```
**Recommended Fix:**
```yaml
strictAudienceValidation: true # Reject sessions with audience mismatch
```
**Result:**
- Default: ⚠️ Works but logs security warnings
- With strict mode: ✅ Secure (rejects mismatched tokens)
---
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
**Configuration:**
```yaml
allowOpaqueTokens: true # Enable opaque token support
requireTokenIntrospection: true # Require introspection (recommended)
```
**What Happens:**
1. Auth0 issues opaque (non-JWT) access token
2. Middleware detects opaque token (not 3 parts separated by dots)
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
**Requirements:**
- Provider must support `introspection_endpoint` in OIDC discovery
- Client must have introspection permissions
**Result:** ✅ Secure with introspection, ⚠️ risky without
---
## Configuration Options
### Audience Settings
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `audience` | string | `client_id` | Expected audience for access tokens |
**Example:**
```yaml
# .traefik.yml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
audience: "https://my-api.example.com"
```
---
### Security Mode Settings
#### `strictAudienceValidation`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` for production
**What it does:**
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
**Example:**
```yaml
strictAudienceValidation: true
```
**When to use:**
- ✅ Always use in production environments
- ✅ When you have custom API audiences configured in Auth0
- ⚠️ May break existing deployments relying on Scenario 2 behavior
---
#### `allowOpaqueTokens`
**Type:** boolean
**Default:** `false`
**What it does:**
- When `true`: Accepts opaque (non-JWT) access tokens
- When `false`: Only accepts JWT access tokens
**Example:**
```yaml
allowOpaqueTokens: true
```
**When to use:**
- ✅ When Auth0 issues opaque tokens (no default API configured)
- ✅ When using Auth0 Management API tokens
- ⚠️ Requires introspection endpoint for security
---
#### `requireTokenIntrospection`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` when `allowOpaqueTokens=true`
**What it does:**
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
- When `false`: Falls back to ID token validation for opaque tokens
**Example:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
**When to use:**
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
- ⚠️ Requires provider to expose introspection endpoint
---
## Security Recommendations
### Recommended Configuration for Auth0
**For APIs with custom audiences (Scenario 1):**
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
allowOpaqueTokens: false
```
**For default Auth0 setup (Scenario 2):**
```yaml
# Don't set audience (defaults to client_id)
strictAudienceValidation: true # Enforce proper configuration
```
**For opaque tokens (Scenario 3):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Security Best Practices
1.**Always set `strictAudienceValidation: true` in production**
2.**Configure custom API audiences in Auth0 dashboard**
3.**Use `requireTokenIntrospection: true` if accepting opaque tokens**
4.**Monitor logs for security warnings**
5.**Don't rely on Scenario 2 fallback behavior**
---
## Troubleshooting
### "Access token validation failed due to audience mismatch"
**Symptom:**
```
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
```
**Cause:** Access token audience doesn't match configured audience
**Solutions:**
1. **Configure correct audience:**
```yaml
audience: "https://your-api-identifier" # From Auth0 API settings
```
2. **Update Auth0 authorization request:**
- Ensure `audience` parameter is included in authorize URL
- Middleware automatically adds this when `audience != client_id`
3. **Accept the behavior (not recommended):**
```yaml
strictAudienceValidation: false # Logs warnings but allows
```
---
### "Opaque token detected but allowOpaqueTokens=false"
**Symptom:**
```
⚠️ Opaque access token detected but allowOpaqueTokens=false
```
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
**Solutions:**
1. **Enable opaque tokens:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
2. **Configure Auth0 to issue JWT access tokens:**
- Create an API in Auth0 dashboard
- Set API identifier as `audience` in configuration
---
### "Introspection endpoint not available"
**Symptom:**
```
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
```
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
**Solutions:**
1. **Check provider discovery:**
```bash
curl https://YOUR_DOMAIN/.well-known/openid-configuration
```
Look for `introspection_endpoint`
2. **Disable required introspection (less secure):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: false # Falls back to ID token
```
3. **Use JWT access tokens instead** (recommended)
---
### "Token introspection required but endpoint not available"
**Symptom:**
```
❌ SECURITY: Opaque token rejected (introspection required but failed)
```
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
**Solutions:**
1. **Disable required introspection:**
```yaml
requireTokenIntrospection: false
```
2. **Configure Auth0 to issue JWT tokens** (better solution)
---
## Advanced Topics
### Token Type Detection
The middleware uses a sophisticated 6-step detection algorithm:
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
2. **Explicit type claims**: `token_use`, `token_type`
3. **`scope` claim**: Present → Access Token
4. **`nonce` claim**: Present → ID Token (OIDC spec)
5. **Audience check**: `aud == client_id` only → ID Token
6. **Default**: Access Token
### OAuth 2.0 Token Introspection (RFC 7662)
When opaque tokens are detected:
1. Middleware calls provider's `introspection_endpoint`
2. Authenticates using client credentials
3. Receives response with `active` status and claims
4. Caches result for 5 minutes (configurable via TTL)
5. Validates expiration, not-before, and audience if present
**Cache behavior:**
- Cache key: Token hash
- TTL: 5 minutes or token expiry (whichever is shorter)
- Reduces introspection requests for frequently used tokens
---
## Reference Links
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
---
## Migration Guide
### From Previous Versions
**If you're upgrading from a version without these features:**
1. **No action required for default behavior** - backward compatible
2. **Recommended: Enable strict mode gradually**
```yaml
# Step 1: Enable and monitor logs
strictAudienceValidation: false # Default
# Step 2: After confirming no warnings, enable
strictAudienceValidation: true
```
3. **For opaque tokens: Enable explicitly**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
### Testing Your Configuration
1. **Check logs for warnings:**
```bash
# Look for Scenario 2 warnings
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
# Look for opaque token warnings
grep "Opaque" /var/log/traefik.log
```
2. **Test with curl:**
```bash
# Get token from Auth0
ACCESS_TOKEN="your_access_token"
# Test request
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
https://your-app.example.com/api
```
3. **Monitor for security warnings in production logs**
---
## Support
For issues or questions:
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
- Security issues: See SECURITY.md for responsible disclosure
---
**Last Updated:** 2025-01-09
**Version:** 0.7.8+
+1
View File
@@ -0,0 +1 @@
traefikoidc.raczylo.com
+456
View File
@@ -0,0 +1,456 @@
# Configuration Reference
Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
- [Access Control](#access-control)
- [Headers Configuration](#headers-configuration)
- [Security Headers](#security-headers)
- [Scope Configuration](#scope-configuration)
- [Advanced Options](#advanced-options)
---
## Required Parameters
| Parameter | Type | Description | Example |
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
### Basic Configuration Example
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
```
---
## Optional Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
### TLS Termination at Load Balancer
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
```yaml
forceHTTPS: true # Required for correct redirect URIs
```
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
---
## Security Options
### Audience Validation
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audience` | string | `clientID` | Expected audience for access token validation |
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
#### Production Security Configuration
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
```
#### Opaque Token Support
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Other Security Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
---
## Session Management
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
### Multi-Subdomain Setup
```yaml
cookieDomain: .example.com # Share cookies across subdomains
```
### Multiple Middleware Instances
When running multiple middleware instances with different authorization requirements, use unique prefixes:
```yaml
# User authentication middleware
---
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-userauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_userauth_"
sessionEncryptionKey: user-encryption-key-min-32-bytes
# ... other config
---
# Admin authentication middleware
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-adminauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_adminauth_"
sessionEncryptionKey: admin-encryption-key-min-32-bytes
allowedUsers:
- admin@example.com
# ... other config
```
### Extended Session Duration
```yaml
sessionMaxAge: 604800 # 7 days
# Common values:
# 3600 - 1 hour (high security)
# 86400 - 1 day (default)
# 259200 - 3 days
# 604800 - 7 days
# 2592000 - 30 days
```
---
## Access Control
### User Restrictions
| Parameter | Type | Description |
|-----------|------|-------------|
| `allowedUserDomains` | []string | Restrict to specific email domains |
| `allowedUsers` | []string | Specific email addresses allowed |
| `allowedRolesAndGroups` | []string | Required roles or groups |
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
### Domain Restriction
```yaml
allowedUserDomains:
- company.com
- subsidiary.com
```
### Specific User Access
```yaml
allowedUsers:
- user@example.com
- contractor@external.org
```
### Role-Based Access Control
```yaml
allowedRolesAndGroups:
- admin
- developer
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
```
### Access Control Logic
- If only `allowedUsers` is set: Only specified emails can access
- If only `allowedUserDomains` is set: Only specified domains can access
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
- If neither is set: Any authenticated user can access
### Users Without Email (Azure AD)
For Azure AD service accounts or users without email:
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
```
---
## Headers Configuration
### Default Headers
The middleware sets these headers for downstream services:
| Header | Description |
|--------|-------------|
| `X-Forwarded-User` | User's email address |
| `X-User-Groups` | Comma-separated user groups |
| `X-User-Roles` | Comma-separated user roles |
| `X-Auth-Request-Redirect` | Original request URI |
| `X-Auth-Request-User` | User's email address |
| `X-Auth-Request-Token` | User's ID token |
### Minimal Headers Mode
For "431 Request Header Fields Too Large" errors:
```yaml
minimalHeaders: true # Only forwards X-Forwarded-User
```
### Custom Templated Headers
```yaml
headers:
- name: "X-User-Email"
value: "{{{{.Claims.email}}}}"
- name: "X-User-ID"
value: "{{{{.Claims.sub}}}}"
- name: "Authorization"
value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles"
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
```
**Template Variables:**
- `{{.Claims.field}}` - ID token claims
- `{{.AccessToken}}` - Raw access token
- `{{.IdToken}}` - Raw ID token
- `{{.RefreshToken}}` - Raw refresh token
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
---
## Security Headers
### Security Profiles
| Profile | Use Case | Security Level |
|---------|----------|----------------|
| `default` | Standard web apps | High |
| `strict` | Maximum security | Very High |
| `development` | Local development | Medium |
| `api` | API endpoints | High |
| `custom` | Custom requirements | Configurable |
### Basic Configuration
```yaml
securityHeaders:
enabled: true
profile: "default"
```
### API with CORS
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.example.com"
corsAllowCredentials: true
```
### Custom Security Configuration
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
# HSTS
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and Content Protection
frameOptions: "DENY"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# CORS
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type"]
corsAllowCredentials: true
corsMaxAge: 86400
# Custom Headers
customHeaders:
X-Custom-Header: "value"
# Server Identification
disableServerHeader: true
disablePoweredByHeader: true
```
### CORS Origin Patterns
```yaml
corsAllowedOrigins:
- "https://example.com" # Exact match
- "https://*.example.com" # Subdomain wildcard
- "http://localhost:*" # Port wildcard (development)
```
---
## Scope Configuration
### Default Behavior (Append Mode)
```yaml
scopes:
- roles
- custom_scope
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
```
### Override Mode
```yaml
overrideScopes: true
scopes:
- openid
- profile
- custom_scope
# Result: ["openid", "profile", "custom_scope"]
```
---
## Advanced Options
### Dynamic Client Registration (RFC 7591)
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
```
### Multi-Replica Deployment
Without Redis, disable replay detection:
```yaml
disableReplayDetection: true
```
With Redis (recommended):
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
See [REDIS.md](REDIS.md) for complete Redis configuration.
---
## Kubernetes Secrets
Reference secrets instead of hardcoding sensitive values:
```yaml
providerURL: urn:k8s:secret:oidc-secret:ISSUER
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
clientSecret: urn:k8s:secret:oidc-secret:SECRET
```
Create the secret:
```bash
kubectl create secret generic oidc-secret \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=your-client-id \
--from-literal=SECRET=your-client-secret \
-n traefik
```
---
## Environment Variable Naming
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
```yaml
# Bad - may cause issues
sessionEncryptionKey: ${OIDC_SECRET_API}
# Good
sessionEncryptionKey: ${OIDC_SECRET_SVC}
```
+455
View File
@@ -0,0 +1,455 @@
# Development Guide
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Local Development Setup](#local-development-setup)
- [Running Tests](#running-tests)
- [Test Categories](#test-categories)
- [CI/CD Pipeline](#cicd-pipeline)
- [Code Quality](#code-quality)
- [Contributing](#contributing)
---
## Prerequisites
- **Go 1.23+** for plugin compilation
- **Docker & Docker Compose** for local testing
- **OIDC Provider** credentials (Google, Azure, etc.)
### Required Development Tools
```bash
# golangci-lint (comprehensive linting)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck (static analysis)
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec (security scanning)
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck (vulnerability scanning)
go install golang.org/x/vuln/cmd/govulncheck@latest
```
---
## Local Development Setup
### Docker Compose Environment
The repository includes a Docker Compose setup for testing the plugin locally.
#### 1. Host Configuration
Add to `/etc/hosts`:
```bash
127.0.0.1 hello.localhost
127.0.0.1 traefik.localhost
```
#### 2. Plugin Configuration
The plugin is loaded using Traefik's **local plugins mode**:
- Plugin source: Parent directory (`../`)
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
- Configuration: `experimental.localPlugins` in `traefik.yml`
#### 3. OIDC Provider Setup
Edit `docker/dynamic.yml` with your provider details:
**Google:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
scopes:
- "openid"
- "email"
- "profile"
```
**Azure AD:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
scopes:
- "openid"
- "email"
- "profile"
```
#### 4. Start Environment
```bash
cd docker
docker-compose up -d
```
#### 5. Test Plugin
- **Protected App**: http://hello.localhost (redirects to OIDC)
- **Traefik Dashboard**: http://traefik.localhost:8080
### Development Workflow
1. **Edit plugin code** in the project root
2. **Build and test** (optional syntax check):
```bash
go mod tidy
go build .
go test ./...
```
3. **Restart Traefik** to reload plugin:
```bash
docker-compose restart traefik
```
4. **Test changes** at http://hello.localhost
### Debugging
**View plugin logs:**
```bash
docker-compose logs -f traefik | grep traefikoidc
```
**Check plugin loading:**
```bash
docker-compose logs traefik | grep -i plugin
```
**Verify plugin directory:**
```bash
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
```
---
## Running Tests
### Quick Start
```bash
# Fast development testing (< 30 seconds)
go test ./... -short
# Standard tests with race detector
go test -race -timeout=15m ./...
# With coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
```
### Test Modes
| Mode | Command | Duration | Use Case |
|------|---------|----------|----------|
| Quick | `go test ./... -short` | < 30s | During development |
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
### Environment Variables
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1
export RUN_LONG_TESTS=1
export RUN_STRESS_TESTS=1
# Disable specific features
export DISABLE_LEAK_DETECTION=1
# Customize test parameters
export TEST_MAX_CONCURRENCY=10
export TEST_MAX_ITERATIONS=50
export TEST_MEMORY_THRESHOLD_MB=25.5
```
---
## Test Categories
### Quick Tests (Default)
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Essential memory leak checks
**Configuration:**
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Timeout: 10 seconds
### Extended Tests
- Comprehensive testing before commits
- More iterations (5-10)
- Enhanced memory leak detection
**Configuration:**
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Timeout: 30 seconds
### Long Tests
- Performance validation
- High iteration counts (50-100)
- Large data sets
**Configuration:**
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Timeout: 60 seconds
### Stress Tests
- Maximum load testing
- Edge case validation
- Extreme parameters
**Configuration:**
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Timeout: 120 seconds
### Running Specific Test Suites
```bash
# Memory leak tests
go test -v -run='.*Leak.*' ./...
# Integration tests
go test -v -run='.*Integration.*' ./...
# Regression tests
go test -v -run='.*Regression.*' ./...
# Provider-specific tests
go test -v -run='.*Azure.*' ./...
go test -v -run='.*Google.*' ./...
```
### Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
---
## CI/CD Pipeline
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
### Triggered On
- Pull requests to `main` branch
- Pushes to `main` branch
### Parallel Jobs
#### Code Quality (3 checks)
- **Format & Basic Checks** - gofmt, go vet, go mod
- **golangci-lint** - 30+ linters
- **Staticcheck** - Advanced static analysis
#### Security (3 checks)
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database
- **CodeQL** - GitHub's semantic code analysis
#### Testing (9 suites)
- Race Detector
- Coverage (75% threshold)
- Memory Leaks
- Integration Tests
- Regression Tests
- Security Edge Cases
- Session Tests
- Token Tests
- CSRF Tests
#### Provider Testing (9 providers)
Tests run in parallel for:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (Go 1.23 & 1.24)
### Quality Gates
All PRs must pass:
- All parallel checks
- 75% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
- All providers tested
- Builds on all platforms
---
## Code Quality
### Pre-Commit Checklist
```bash
# Run before every commit
gofmt -s -w . && \
go mod tidy && \
golangci-lint run && \
go test -race -short ./... && \
echo "Ready to commit!"
```
### Local Validation
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Static analysis
staticcheck ./...
# Security scan
gosec ./...
# Vulnerability check
govulncheck ./...
# Tests with race detector
go test -race -timeout=15m -count=1 ./...
# Coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# View coverage in browser
go tool cover -html=coverage.out
```
### Troubleshooting
**Coverage Below Threshold:**
```bash
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out # See uncovered lines
```
**Race Condition Found:**
```bash
go test -race -v -run=TestName ./...
```
**Linter Errors:**
```bash
golangci-lint run -v
golangci-lint run --fix # Auto-fix some issues
```
**Provider Test Fails:**
```bash
go test -v -run='.*Azure.*' ./...
```
---
## Contributing
### Development Guidelines
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
4. **Documentation**: Update README and configuration files for new options
### Pull Request Template
PRs should include:
- Description of changes
- Type of change (bug fix, feature, breaking change, etc.)
- Related issues
- Provider impact (which providers are affected)
- Testing performed
- Security considerations
- Performance impact
- Breaking changes (if any)
### Checklist
Before submitting:
- [ ] Code follows project style
- [ ] Self-review completed
- [ ] Tests added for new functionality
- [ ] All tests pass locally
- [ ] Documentation updated
- [ ] No new warnings generated
### Code Owners
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
### Dependabot
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
---
## Additional Resources
- [golangci-lint Rules](.golangci.yml)
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
- [Workflow Documentation](.github/workflows/README.md)
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
+582
View File
@@ -0,0 +1,582 @@
# OIDC Provider Configuration Guide
Configuration reference for each supported OIDC provider.
## Table of Contents
- [Provider Support Matrix](#provider-support-matrix)
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [Okta](#okta)
- [Keycloak](#keycloak)
- [AWS Cognito](#aws-cognito)
- [GitLab](#gitlab)
- [GitHub](#github)
- [Generic OIDC](#generic-oidc)
- [Automatic Scope Filtering](#automatic-scope-filtering)
---
## Provider Support Matrix
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com` | Yes |
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
| Generic | Full | Yes | Any OIDC endpoint | Yes |
---
## Google
### Provider URL
```yaml
providerURL: "https://accounts.google.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
spec:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-id.apps.googleusercontent.com"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- email
- profile
allowedUserDomains:
- "your-gsuite-domain.com" # Optional: Workspace restriction
forceHttps: true
enablePkce: true
```
### Google-Specific Features
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
### Google Cloud Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Navigate to APIs & Services > Credentials
4. Create OAuth 2.0 Client ID (Web application)
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
6. Configure OAuth consent screen (must be "Published" for production)
---
## Microsoft Azure AD
### Provider URL
```yaml
# Single tenant
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# Multi-tenant
providerURL: "https://login.microsoftonline.com/common/v2.0"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
allowedRolesAndGroups:
- "App.Users"
- "Admin.Group"
forceHttps: true
```
### With Application ID URI (API Access)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-api
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
audience: "api://your-azure-client-id" # Application ID URI
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
forceHttps: true
```
### Users Without Email
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "user-object-id-1"
- "user-object-id-2"
```
### Azure AD Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to Azure Active Directory > App registrations
3. Create new registration
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
5. Create client secret in Certificates & secrets
6. Configure Token Configuration for group claims
---
## Auth0
### Provider URL
```yaml
providerURL: "https://your-domain.auth0.com"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
postLogoutRedirectUri: "https://your-app.com"
forceHttps: true
enablePkce: true
```
### With Custom API Audience
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-api
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
audience: "https://api.your-domain.com" # API identifier
strictAudienceValidation: true
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
roleClaimName: "https://your-app.com/roles" # Namespaced claim
groupClaimName: "https://your-app.com/groups"
allowedRolesAndGroups:
- admin
- editor
```
### Auth0 Action for Custom Claims
```javascript
exports.onExecutePostLogin = async (event, api) => {
const namespace = 'https://your-app.com/';
if (event.authorization) {
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
api.idToken.setCustomClaim('email', event.user.email);
}
};
```
### Auth0 Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create Regular Web Application
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
5. Enable OIDC Conformant in Advanced Settings
6. Create API in APIs section for custom audiences
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
---
## Okta
### Provider URL
```yaml
providerURL: "https://your-domain.okta.com"
# Or with custom authorization server:
providerURL: "https://your-domain.okta.com/oauth2/default"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-okta
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.okta.com"
clientID: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- groups
- offline_access
allowedRolesAndGroups:
- admin
- "Everyone"
forceHttps: true
enablePkce: true
```
### Okta Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect > Web Application
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
6. Enable Authorization Code and Refresh Token grant types
7. Configure Groups claim in authorization server
---
## Keycloak
### Provider URL
```yaml
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-keycloak
spec:
plugin:
traefikoidc:
providerURL: "https://keycloak.company.com/realms/your-realm"
clientID: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- roles
- groups
- offline_access
allowedRolesAndGroups:
- admin
- editor
forceHttps: true
enablePkce: true
```
### Internal Network Deployment
For private IP addresses (Docker networks, Kubernetes):
```yaml
providerURL: "https://192.168.1.100:8443/realms/your-realm"
allowPrivateIPAddresses: true # Required for private IPs
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select your realm
3. Go to Clients > Create client
4. Set Client Protocol: openid-connect
5. Set Access Type: confidential
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
7. Generate client secret in Credentials tab
8. Configure mappers to add claims to ID Token:
- Email: User Property mapper with "Add to ID token" enabled
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
---
## AWS Cognito
### Provider URL
```yaml
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-cognito
spec:
plugin:
traefikoidc:
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientID: "your-cognito-client-id"
clientSecret: "your-cognito-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- aws.cognito.signin.user.admin
allowedRolesAndGroups:
- admin
- users
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/oauth2/callback`
- Sign out URLs: `https://your-domain.com/oauth2/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
5. Set up groups for role-based access
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerURL: "https://gitlab.com"
# Self-hosted
providerURL: "https://gitlab.your-company.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-gitlab
spec:
plugin:
traefikoidc:
providerURL: "https://gitlab.com"
clientID: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
# Note: GitLab doesn't require offline_access scope
# Refresh tokens are issued automatically with openid
allowedRolesAndGroups:
- developers
- maintainers
forceHttps: true
enablePkce: true
```
### GitLab Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
5. Save and note Application ID and Secret
---
## GitHub
### Provider URL
```yaml
providerURL: "https://github.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oauth-github
spec:
plugin:
traefikoidc:
providerURL: "https://github.com/login/oauth"
clientID: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- user:email
- read:user
allowedUsers:
- "github-username"
forceHttps: true
```
### Limitations
- **OAuth 2.0 only** - Not OpenID Connect
- **No ID tokens** - Only access tokens for API calls
- **No refresh tokens** - Users must re-authenticate on expiry
- **No standard claims** - User info requires API calls
Use GitHub only for API access, not for user authentication with claims.
### GitHub Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
4. Note Client ID and generate Client Secret
---
## Generic OIDC
For any OIDC-compliant provider not listed above.
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-generic
spec:
plugin:
traefikoidc:
providerURL: "https://oidc.your-provider.com"
clientID: "your-client-id"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
forceHttps: true
enablePkce: true
```
### Requirements
- Provider must expose `.well-known/openid-configuration` endpoint
- Must support authorization code flow
- ID tokens must contain required claims (email, sub, etc.)
---
## Automatic Scope Filtering
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
### How It Works
1. Fetches provider's `.well-known/openid-configuration`
2. Extracts `scopes_supported` field
3. Filters requested scopes to only include supported ones
4. Falls back to all requested scopes if provider doesn't declare supported scopes
### Example: Self-Hosted GitLab
Self-hosted GitLab may reject `offline_access` scope:
```yaml
scopes:
- openid
- profile
- email
- offline_access # Will be automatically filtered out if unsupported
```
The middleware will:
1. Read GitLab's discovery document
2. Detect `offline_access` is NOT in `scopes_supported`
3. Filter it out automatically
4. Authentication succeeds
### Logging
```
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
If a provider rejects scopes even after filtering:
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
2. Use `overrideScopes: true` with only supported scopes
3. Review middleware debug logs for filtering decisions
-770
View File
@@ -1,770 +0,0 @@
# Provider-Specific Configuration Guide
This guide covers the configuration requirements and best practices for each supported OIDC provider.
## Table of Contents
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [GitHub](#github)
- [GitLab](#gitlab)
- [AWS Cognito](#aws-cognito)
- [Keycloak](#keycloak)
- [Okta](#okta)
- [Generic OIDC](#generic-oidc)
---
## Google
### Provider URL
```yaml
providerUrl: "https://accounts.google.com"
```
### Required Configuration
```yaml
clientId: "your-google-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Google-Specific Features
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
- **Refresh token support**: Fully supported
- **Domain restrictions**: Can restrict by Google Workspace domains
### Example Configuration
```yaml
# Traefik dynamic configuration
http:
middlewares:
google-oidc:
plugin:
traefik-oidc:
providerUrl: "https://accounts.google.com"
clientId: "123456789-abcdef.apps.googleusercontent.com"
clientSecret: "GOCSPX-your-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
allowedUserDomains: ["example.com", "company.org"]
forceHttps: true
enablePkce: true
```
### Google OAuth Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
---
## Microsoft Azure AD
### Provider URL
```yaml
# For Azure AD (single tenant)
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# For Azure AD (multi-tenant)
providerUrl: "https://login.microsoftonline.com/common/v2.0"
```
### Required Configuration
```yaml
clientId: "your-azure-application-id"
clientSecret: "your-azure-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Azure-Specific Features
- **Response mode**: Automatically adds `response_mode=query`
- **Offline access**: Requires `offline_access` scope for refresh tokens
- **Access token validation**: Supports both JWT and opaque access tokens
- **Tenant isolation**: Can restrict to specific Azure AD tenants
### Example Configuration
```yaml
http:
middlewares:
azure-oidc:
plugin:
traefik-oidc:
providerUrl: "https://login.microsoftonline.com/common/v2.0"
clientId: "12345678-1234-1234-1234-123456789abc"
clientSecret: "your-azure-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
forceHttps: true
```
### Azure App Registration Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to "Azure Active Directory" > "App registrations"
3. Create new registration
4. Add redirect URI: `https://your-domain.com/auth/callback`
5. Create client secret in "Certificates & secrets"
6. Configure API permissions for required scopes
---
## Auth0
### Provider URL
```yaml
providerUrl: "https://your-domain.auth0.com"
```
### Required Configuration
```yaml
clientId: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Auth0-Specific Features
- **Custom domains**: Supports Auth0 custom domains
- **Rules and hooks**: Leverages Auth0's extensibility
- **Social connections**: Works with Auth0's social identity providers
- **Offline access**: Requires `offline_access` scope
### Example Configuration
```yaml
http:
middlewares:
auth0-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.auth0.com"
clientId: "abcdef123456789"
clientSecret: "your-auth0-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedUsers: ["user@example.com", "admin@company.com"]
forceHttps: true
enablePkce: true
```
### Auth0 Application Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create new application (Regular Web Application)
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
5. Enable OIDC Conformant in Advanced Settings
---
## GitHub
### Provider URL
```yaml
providerUrl: "https://github.com"
```
### Required Configuration
```yaml
clientId: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["read:user", "user:email"]
```
### GitHub-Specific Features
- **Organization membership**: Can restrict by GitHub organization
- **Team membership**: Can restrict by specific teams
- **Limited OIDC**: GitHub has limited OIDC support
- **Email verification**: Requires verified email addresses
### Example Configuration
```yaml
http:
middlewares:
github-oidc:
plugin:
traefik-oidc:
providerUrl: "https://github.com"
clientId: "Iv1.abcdef123456"
clientSecret: "your-github-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["read:user", "user:email"]
allowedUsers: ["octocat", "github-user"]
forceHttps: true
```
### GitHub OAuth App Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
4. Note the Client ID and generate Client Secret
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerUrl: "https://gitlab.com"
# Self-hosted GitLab
providerUrl: "https://gitlab.your-company.com"
```
### Required Configuration
```yaml
clientId: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### GitLab-Specific Features
- **Self-hosted support**: Works with self-hosted GitLab instances
- **Group membership**: Can restrict by GitLab groups
- **Project access**: Can validate project permissions
- **Offline access**: Supports refresh tokens with `offline_access`
### Example Configuration
```yaml
http:
middlewares:
gitlab-oidc:
plugin:
traefik-oidc:
providerUrl: "https://gitlab.com"
clientId: "abcdef123456789"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["developers", "maintainers"]
forceHttps: true
enablePkce: true
```
### GitLab Application Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/auth/callback`
5. Save and note the Application ID and Secret
---
## AWS Cognito
### Provider URL
```yaml
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Required Configuration
```yaml
clientId: "your-cognito-app-client-id"
clientSecret: "your-cognito-app-client-secret" # If app client has secret
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Cognito-Specific Features
- **User pools**: Integrates with Cognito User Pools
- **Custom attributes**: Supports custom user attributes
- **Groups**: Can validate Cognito user group membership
- **Regional endpoints**: Requires region-specific URLs
### Example Configuration
```yaml
http:
middlewares:
cognito-oidc:
plugin:
traefik-oidc:
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientId: "1234567890abcdefghij"
clientSecret: "your-cognito-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
allowedRolesAndGroups: ["admin", "users"]
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/auth/callback`
- Sign out URLs: `https://your-domain.com/auth/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
---
## Keycloak
### Provider URL
```yaml
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
```
### Required Configuration
```yaml
clientId: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Keycloak-Specific Features
- **Realm support**: Multi-realm deployments
- **Custom mappers**: Rich claim mapping capabilities
- **Role-based access**: Fine-grained role management
- **Offline access**: Full refresh token support
### Example Configuration
```yaml
http:
middlewares:
keycloak-oidc:
plugin:
traefik-oidc:
providerUrl: "https://keycloak.company.com/realms/employees"
clientId: "traefik-app"
clientSecret: "your-keycloak-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["app-users", "administrators"]
forceHttps: true
enablePkce: true
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select appropriate realm
3. Create new client:
- Client Protocol: openid-connect
- Access Type: confidential
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
4. Configure client scopes and mappers
5. Generate client secret in Credentials tab
---
## Okta
### Provider URL
```yaml
providerUrl: "https://your-domain.okta.com"
```
### Required Configuration
```yaml
clientId: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
```
### Okta-Specific Features
- **Custom authorization servers**: Supports custom auth servers
- **Group claims**: Rich group membership information
- **Universal Directory**: Integrates with Okta's user store
- **Offline access**: Requires `offline_access` scope
### Example Configuration
```yaml
http:
middlewares:
okta-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.okta.com"
clientId: "0oa123456789abcdef"
clientSecret: "your-okta-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
postLogoutRedirectUri: "https://app.example.com"
scopes: ["openid", "profile", "email", "offline_access"]
allowedRolesAndGroups: ["Everyone", "Administrators"]
forceHttps: true
enablePkce: true
```
### Okta Application Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect
4. Choose Web Application
5. Configure:
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
- Grant types: Authorization Code, Refresh Token
6. Assign users or groups
---
## Generic OIDC
### Provider URL
```yaml
providerUrl: "https://your-oidc-provider.com"
```
### Required Configuration
```yaml
clientId: "your-client-id"
clientSecret: "your-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
scopes: ["openid", "profile", "email"]
```
### Generic Features
- **Standards compliance**: Works with any OIDC-compliant provider
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
- **Flexible scopes**: Supports custom scope requirements
- **Custom claims**: Works with provider-specific claims
### Example Configuration
```yaml
http:
middlewares:
generic-oidc:
plugin:
traefik-oidc:
providerUrl: "https://oidc.your-provider.com"
clientId: "your-client-id"
clientSecret: "your-client-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email"]
forceHttps: true
enablePkce: true
```
---
## Common Configuration Options
### Security Settings
```yaml
# Force HTTPS (recommended for production)
forceHttps: true
# Enable PKCE (recommended for security)
enablePkce: true
# Session encryption key (32+ characters)
sessionEncryptionKey: "your-very-long-encryption-key-here"
```
### Access Control
```yaml
# Restrict by email addresses
allowedUsers: ["user1@example.com", "user2@example.com"]
# Restrict by email domains
allowedUserDomains: ["company.com", "partner.org"]
# Restrict by roles/groups (provider-specific)
allowedRolesAndGroups: ["admin", "users", "developers"]
```
### URLs and Endpoints
```yaml
# OAuth callback URL (must match provider config)
callbackUrl: "https://your-domain.com/auth/callback"
# Logout endpoint
logoutUrl: "https://your-domain.com/auth/logout"
# Post-logout redirect (optional)
postLogoutRedirectUri: "https://your-domain.com"
# URLs to exclude from authentication
excludedUrls: ["/health", "/metrics", "/public"]
```
### Advanced Settings
```yaml
# Override default scopes
overrideScopes: true
scopes: ["openid", "custom_scope"]
# Rate limiting (requests per second)
rateLimit: 10
# Token refresh grace period (seconds)
refreshGracePeriodSeconds: 60
# Cookie domain (for subdomain sharing)
cookieDomain: ".example.com"
# Custom headers to inject
headers:
- name: "X-User-Email"
value: "{{.Claims.email}}"
- name: "X-User-Name"
value: "{{.Claims.name}}"
```
---
## Troubleshooting
### Common Issues
1. **Invalid redirect URI**
- Ensure callback URL exactly matches provider configuration
- Check for HTTP vs HTTPS mismatches
2. **Scope errors**
- Verify required scopes are configured in provider
- Some providers require specific scopes for refresh tokens
3. **Token validation failures**
- Check provider URL format and accessibility
- Verify `.well-known/openid-configuration` endpoint is reachable
4. **Session issues**
- Ensure session encryption key is properly configured
- Check cookie domain settings for subdomain scenarios
### Debug Mode
Enable debug logging to troubleshoot configuration issues:
```yaml
logLevel: "debug"
```
This will provide detailed logs of the authentication flow and help identify configuration problems.
---
## Security Headers Configuration
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
### Default Security Headers
By default, the plugin applies these security headers:
- `X-Frame-Options: DENY` - Prevents clickjacking
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
### Security Profiles
Choose from predefined security profiles or create custom configurations:
#### Default Profile (Recommended)
```yaml
securityHeaders:
enabled: true
profile: "default"
```
#### Strict Profile (Maximum Security)
```yaml
securityHeaders:
enabled: true
profile: "strict"
# Additional strict CSP and cross-origin policies
```
#### Development Profile (Local Development)
```yaml
securityHeaders:
enabled: true
profile: "development"
# Relaxed policies for local development
```
#### API Profile (API Endpoints)
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins: ["https://your-frontend.com"]
```
### Custom Security Configuration
For complete control, use the custom profile:
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
# HSTS Configuration
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000 # 1 year
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and content protection
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# Permissions policy (feature policy)
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
# Cross-origin policies
crossOriginEmbedderPolicy: "require-corp"
crossOriginOpenerPolicy: "same-origin"
crossOriginResourcePolicy: "same-origin"
# CORS configuration
corsEnabled: true
corsAllowedOrigins:
- "https://app.example.com"
- "https://*.api.example.com"
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
corsAllowCredentials: true
corsMaxAge: 86400 # 24 hours
# Custom headers
customHeaders:
X-Custom-Header: "custom-value"
X-API-Version: "v1"
# Server identification
disableServerHeader: true
disablePoweredByHeader: true
```
### Complete Example with Security Headers
Here's a complete configuration example for Google OIDC with custom security headers:
```yaml
# Traefik dynamic configuration
http:
middlewares:
secure-google-oidc:
plugin:
traefik-oidc:
# OIDC Configuration
providerUrl: "https://accounts.google.com"
clientId: "123456789-abcdef.apps.googleusercontent.com"
clientSecret: "GOCSPX-your-client-secret"
callbackUrl: "https://your-domain.com/auth/callback"
sessionEncryptionKey: "your-32-character-encryption-key-here"
# Domain restrictions
allowedUserDomains: ["your-company.com"]
# Security Headers
securityHeaders:
enabled: true
profile: "strict"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.your-domain.com"
corsAllowCredentials: true
customHeaders:
X-Company: "YourCompany"
X-Environment: "production"
routers:
secure-app:
rule: "Host(`your-domain.com`)"
middlewares:
- secure-google-oidc
service: your-app-service
tls:
certResolver: letsencrypt
```
### CORS Configuration Details
For applications with frontend-backend separation, configure CORS properly:
#### Simple CORS (Single Origin)
```yaml
securityHeaders:
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowCredentials: true
```
#### Wildcard Subdomains
```yaml
securityHeaders:
corsEnabled: true
corsAllowedOrigins: ["https://*.example.com"]
corsAllowCredentials: true
```
#### Development with Multiple Ports
```yaml
securityHeaders:
profile: "development"
corsEnabled: true
corsAllowedOrigins:
- "http://localhost:*"
- "http://127.0.0.1:*"
```
### Security Best Practices
1. **Always use HTTPS in production**
- Set `forceHttps: true`
- Configure proper TLS certificates
2. **Implement proper CSP**
- Start with strict policy
- Add exceptions only when necessary
- Test thoroughly
3. **Configure CORS restrictively**
- Only allow necessary origins
- Use specific domains instead of wildcards when possible
4. **Enable HSTS**
- Use long max-age values (1 year minimum)
- Include subdomains when appropriate
5. **Monitor security headers**
- Use browser developer tools to verify headers
- Test with security scanning tools
- Regularly review and update policies
### Testing Security Headers
Use browser developer tools or online tools to verify your security headers:
1. **Browser DevTools**: Check Network tab → Response Headers
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
3. **Command line**: Use `curl -I https://your-domain.com`
Example verification:
```bash
curl -I https://your-domain.com
# Should show security headers in response
```
+546
View File
@@ -0,0 +1,546 @@
# Redis Cache for Distributed Deployments
Redis cache support for multi-replica Traefik deployments with shared state.
## Table of Contents
- [Overview](#overview)
- [Why Use Redis Cache?](#why-use-redis-cache)
- [Configuration](#configuration)
- [Cache Modes](#cache-modes)
- [Deployment Examples](#deployment-examples)
- [Performance Tuning](#performance-tuning)
- [Monitoring](#monitoring)
- [Troubleshooting](#troubleshooting)
- [Migration Guide](#migration-guide)
---
## Overview
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
### Key Features
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
- **Shared Session Management**: Consistent user sessions across replicas
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
- **Health Checking**: Continuous monitoring of Redis connectivity
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
### Architecture
```
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
│ │ │
└────────────────────┼────────────────────┘
┌──────▼──────┐
│ Redis │
│ (Shared │
│ Cache) │
└─────────────┘
```
---
## Why Use Redis Cache?
### The Problem
When running multiple Traefik instances without shared cache:
1. **False Positive Replay Detection**
- User authenticates → Token stored in Instance A's JTI cache
- Next request → Load balancer routes to Instance B
- Instance B doesn't have the JTI → Falsely detects replay attack
2. **Session Inconsistency**
- User session created on Instance A
- Subsequent request routed to Instance B
- Instance B has no knowledge of the session
3. **Token Metadata Fragmentation**
- Token refresh happens on Instance A
- Other instances continue using old tokens
### The Solution
Redis provides centralized cache that all instances share, ensuring:
- **Consistent Authentication**: All instances share authentication state
- **True Replay Detection**: JTI cache shared across all instances
- **Seamless Scaling**: Add/remove instances without affecting sessions
- **High Availability**: Circuit breaker with automatic fallback
---
## Configuration
### Basic Configuration
```yaml
redis:
enabled: true
address: "redis:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
```
### All Configuration Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enabled` | bool | `false` | Enable Redis caching |
| `address` | string | - | Redis server address (`host:port`) |
| `password` | string | - | Redis password (optional) |
| `db` | int | `0` | Redis database number (0-15) |
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
| `poolSize` | int | `10` | Connection pool size |
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
| `readTimeout` | int | `3` | Read timeout (seconds) |
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables are used:
```bash
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=your-password
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=10
REDIS_CONNECT_TIMEOUT=5
REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
```
---
## Cache Modes
### Memory Mode (Default without Redis)
```yaml
redis:
cacheMode: "memory"
```
- Uses only in-memory cache
- Suitable for single-instance deployments
- No Redis dependency
- Fastest performance
### Redis Mode
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
- All operations go directly to Redis
- Ensures consistency across replicas
- Slightly higher latency
### Hybrid Mode (Recommended)
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
Two-tier caching strategy:
```
┌─────────────────────────────────────────┐
│ Client Request │
└────────────────┬────────────────────────┘
┌────────────────┐
│ Local Cache │ ← L1 Cache (Fast)
│ (Memory) │
└────────┬───────┘
│ Miss
┌────────────────┐
│ Remote Cache │ ← L2 Cache (Shared)
│ (Redis) │
└────────────────┘
```
**Read Path:**
1. Check local memory cache (L1)
2. On miss, check Redis (L2)
3. On hit in Redis, populate L1
4. Return value
**Write Path:**
1. Write to Redis (L2) for durability
2. Write to local cache (L1) for speed
### Performance Comparison
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|-----------|------------|------------|-------------|
| Read (p50) | 0.1ms | 2ms | 0.2ms |
| Read (p99) | 0.5ms | 10ms | 5ms |
| Write (p50) | 0.2ms | 3ms | 3ms |
| Throughput | 100k/s | 20k/s | 80k/s |
---
## Deployment Examples
### Docker Compose
```yaml
version: '3.8'
services:
redis:
image: redis:7-alpine
command: redis-server --requirepass ${REDIS_PASSWORD}
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 30s
timeout: 3s
retries: 3
traefik:
image: traefik:v3.2
deploy:
replicas: 3
labels:
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
depends_on:
redis:
condition: service_healthy
volumes:
redis-data:
```
### Kubernetes
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
circuitBreakerThreshold: 5
```
### AWS ElastiCache
```yaml
redis:
enabled: true
address: "your-cache.abc123.cache.amazonaws.com:6379"
cacheMode: "hybrid"
enableTLS: true
password: "your-elasticache-auth-token"
```
---
## Performance Tuning
### Connection Pool Sizing
```yaml
redis:
poolSize: 20 # Formula: 2 * CPU cores * replicas
# For 4 cores, 3 replicas: poolSize = 24
```
### TTL Strategy
The plugin automatically sets TTLs based on token lifetimes:
- **JTI Cache**: Matches token lifetime (typically 1 hour)
- **Session**: Matches `sessionMaxAge` configuration
- **Token Metadata**: 5 minutes (short-lived)
### Redis Server Configuration
```bash
# Recommended Redis settings for cache
maxmemory 512mb
maxmemory-policy allkeys-lru # Evict least recently used
# For cache data, disable persistence for better performance
save ""
appendonly no
```
### Hybrid Mode Tuning
```yaml
redis:
cacheMode: "hybrid"
hybridL1Size: 500 # Max items in local cache
hybridL1MemoryMB: 10 # Max memory for local cache
```
---
## Monitoring
### Key Metrics
- **Cache hit rate** (target: >90% for hybrid mode)
- **Redis latency** (target: <10ms p99)
- **Circuit breaker state**
- **Connection pool utilization
### Redis Commands for Monitoring
```bash
# Monitor commands in real-time
redis-cli MONITOR
# Check slow queries
redis-cli SLOWLOG GET 10
# Memory usage
redis-cli INFO memory
# Key statistics
redis-cli DBSIZE
# List keys with prefix
redis-cli --scan --pattern "traefikoidc:*"
# Check key TTL
redis-cli TTL "traefikoidc:session:abc123"
```
### Health Check Endpoint
The plugin provides health information including:
```json
{
"status": "healthy",
"cache": {
"mode": "hybrid",
"redis": {
"connected": true,
"latency": "2ms"
},
"circuit_breaker": {
"state": "closed",
"failures": 0
}
}
}
```
---
## Troubleshooting
### Connection Refused
**Symptoms:** `dial tcp: connection refused`
**Solutions:**
1. Verify Redis is running: `redis-cli ping`
2. Check network connectivity: `telnet redis-host 6379`
3. Verify address configuration
### Authentication Failure
**Symptoms:** `NOAUTH Authentication required`
**Solutions:**
1. Set Redis password in configuration
2. Verify password is correct
### Circuit Breaker Open
**Symptoms:** `Circuit breaker is open`, falling back to memory
**Solutions:**
1. Check Redis health: `redis-cli INFO server`
2. Review network latency: `redis-cli --latency`
3. Adjust circuit breaker thresholds if needed
### High Memory Usage
**Symptoms:** Redis memory constantly growing, OOM errors
**Solutions:**
1. Configure eviction policy:
```bash
CONFIG SET maxmemory 512mb
CONFIG SET maxmemory-policy allkeys-lru
```
2. Review key count: `redis-cli DBSIZE`
3. Check for large keys: `redis-cli --bigkeys`
### Inconsistent Cache State
**Symptoms:** Different responses from different replicas
**Solutions:**
1. Verify all instances use the same Redis address
2. Check cache mode consistency across instances
3. Verify time synchronization on all hosts
---
## Migration Guide
### From Memory-Only to Redis
#### Phase 1: Preparation
1. Deploy Redis infrastructure
2. Test Redis connectivity
3. Configure monitoring
#### Phase 2: Gradual Rollout
1. Enable Redis on one instance:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
2. Monitor for errors
3. Gradually enable on more instances
#### Phase 3: Full Migration
1. Enable Redis on all instances
2. Remove `disableReplayDetection: true` if set
3. Monitor for issues
### Rollback Plan
If issues occur:
1. Set `redis.enabled: false`
2. Plugin falls back to memory cache automatically
3. Investigate and resolve issues
### Migration Checklist
- [ ] Redis deployed and accessible
- [ ] Redis password configured
- [ ] Network connectivity verified
- [ ] Monitoring configured
- [ ] Backup plan prepared
- [ ] Test environment validated
- [ ] Gradual rollout planned
---
## Best Practices
### Security
- Always use Redis password authentication
- Enable TLS for production deployments
- Use network segmentation (private subnets)
- Rotate Redis passwords regularly
### High Availability
- Use Redis Sentinel or Cluster for HA
- Configure appropriate circuit breaker thresholds
- Implement proper health checks
- Use connection pooling
### Performance
- Use hybrid cache mode for best performance
- Monitor cache hit rates
- Size Redis memory appropriately
- Disable persistence for cache-only usage
### Operations
- Implement comprehensive monitoring
- Set up alerting for circuit breaker state
- Document Redis configuration
- Test failover scenarios
---
## FAQ
### Is Redis required?
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
### What happens if Redis goes down?
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
### Which cache mode should I use?
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
### How much memory does Redis need?
Depends on active sessions and token sizes:
- Small (1-1000 users): 128MB
- Medium (1000-10000 users): 256-512MB
- Large (10000+ users): 1GB+
### Can I use managed Redis services?
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
### Is data encrypted in Redis?
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
+390
View File
@@ -0,0 +1,390 @@
# Testing Guide
Comprehensive testing infrastructure for traefikoidc.
## Overview
| Metric | Value |
|--------|-------|
| Test files | 99 |
| Lines of test code | ~65,500 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
## Running Tests
```bash
# Run all tests
go test ./...
# Run with race detection
go test -race ./...
# Run with coverage
go test -cover ./...
# Run specific test suite
go test -v -run "TokenValidationSuite" .
# Run edge case tests
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
```
## Test Infrastructure
### Directory Structure
```
internal/testutil/
├── compat.go # Re-exports for main package access
├── mocks/
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
│ ├── session.go # SessionManager, SessionData
│ ├── cache.go # Cache, TokenCache, Blacklist
│ └── interfaces_test.go # Mock verification tests
├── fixtures/
│ └── tokens.go # JWT token generation fixtures
└── servers/
├── oidc.go # Mock OIDC server factory
└── oidc_test.go # Server tests
```
### Test Suites
| Suite | File | Description |
|-------|------|-------------|
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
## Mock Types
The project provides two mocking patterns:
### State-Based Mocks (Basic)
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
| Mock | Interface | Description |
|------|-----------|-------------|
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
**Usage:**
```go
mock := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tOidc := &TraefikOidc{
jwkCache: mock,
// ...
}
```
### Enhanced State-Based Mocks (with Call Tracking)
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
| Mock | Interface | Description |
|------|-----------|-------------|
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
**Usage:**
```go
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
}
// Make calls
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
// Verify calls were made
mock.AssertGetJWKSCalled(t)
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
mock.AssertGetJWKSCallCount(t, 1)
// Access call details
s.Equal(1, mock.GetJWKSCallCount())
```
**Features:**
- Track all calls with parameters and timestamps
- Built-in assertion helpers using testify
- Thread-safe for concurrent tests
- `Reset()` method to clear state between tests
- `LastCall()` to inspect most recent call
### Testify-Based Mocks
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
| Mock | Interface | Description |
|------|-----------|-------------|
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
**Usage:**
```go
mock := &TestifyJWKCache{}
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
// After test
mock.AssertExpectations(t)
```
### Testutil Package Mocks
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
```go
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
mock := testutil.NewJWKCacheMock()
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
```
### Choosing the Right Mock
| Use Case | Recommended Mock |
|----------|-----------------|
| Simple return values only | Basic state-based (`MockJWKCache`) |
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
| Verify call order/sequence | Testify-based |
| HTTP endpoint simulation | `MockOAuthProvider` |
| New testify suite tests | Enhanced or Testify-based |
**Decision Guide:**
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
## Token Fixtures
The `testutil.TokenFixture` generates JWT tokens for testing:
```go
fixture, err := testutil.NewTokenFixture()
// Valid token with default claims
token, _ := fixture.ValidToken(nil)
// Token with custom claims
token, _ := fixture.ValidToken(map[string]interface{}{
"email": "test@example.com",
"roles": []string{"admin"},
})
// Expired token
token, _ := fixture.ExpiredToken()
// Token with specific roles/groups
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
token, _ := fixture.TokenWithGroups([]string{"developers"})
// Token with clock skew
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
// Token missing specific claims
token, _ := fixture.TokenMissingClaim("email", "sub")
// Malformed token
token := fixture.MalformedToken() // "not.a.valid.jwt"
// Get JWKS for verification
jwks := fixture.GetJWKS()
```
## Mock OIDC Server
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
```go
// Default configuration
server := testutil.NewOIDCServer(nil)
defer server.Close()
// Custom configuration
config := testutil.DefaultServerConfig()
config.Issuer = "https://custom-issuer.com"
config.TokenError = &testutil.OIDCError{
Error: "invalid_grant",
Description: "Authorization code expired",
}
server := testutil.NewOIDCServer(config)
// Provider-specific configurations
googleConfig := testutil.GoogleServerConfig()
azureConfig := testutil.AzureServerConfig()
auth0Config := testutil.Auth0ServerConfig()
keycloakConfig := testutil.KeycloakServerConfig()
// Behavior configurations
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
```
### Server Endpoints
| Endpoint | Description |
|----------|-------------|
| `/.well-known/openid-configuration` | OIDC discovery document |
| `/authorize` | Authorization endpoint |
| `/token` | Token exchange endpoint |
| `/jwks` | JSON Web Key Set |
| `/userinfo` | User information endpoint |
| `/introspect` | Token introspection |
| `/revoke` | Token revocation |
| `/logout` | End session endpoint |
### Request Tracking
```go
server := testutil.NewOIDCServer(nil)
// Make requests...
count := server.GetRequestCount()
requests := server.GetRequests()
server.Reset() // Clear tracking
```
## Writing Test Suites
### Basic Suite Structure
```go
type MyTestSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *MyTestSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *MyTestSuite) SetupTest() {
// Per-test setup
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
// ...
}
}
func (s *MyTestSuite) TearDownTest() {
// Per-test cleanup
}
func (s *MyTestSuite) TestSomething() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err)
}
func TestMyTestSuite(t *testing.T) {
suite.Run(t, new(MyTestSuite))
}
```
### Table-Driven Tests
```go
func (s *MyTestSuite) TestClockSkewEdgeCases() {
testCases := []struct {
name string
skew time.Duration
shouldPass bool
}{
{"valid_token", 5 * time.Minute, true},
{"expired_within_tolerance", -1 * time.Minute, true},
{"expired_beyond_tolerance", -10 * time.Minute, false},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
token, err := s.fixture.TokenWithSkew(tc.skew)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
if tc.shouldPass {
s.NoError(err)
} else {
s.Error(err)
}
})
}
}
```
## Test Categories
### Happy Path Tests
Test the expected successful scenarios:
- Valid token verification
- Successful token exchange
- Session creation and retrieval
- Cache operations
### Error Case Tests
Test failure scenarios:
- Expired tokens
- Invalid signatures
- Wrong issuer/audience
- Network failures
- Rate limiting
### Edge Case Tests
Test boundary conditions:
- Clock skew tolerance boundaries
- Unicode/emoji in claims
- Very large claim values
- Concurrent access
- Special characters in URLs
## Best Practices
1. **Use fixtures for token generation** - Don't manually construct JWTs
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
3. **Always run with `-race`** - Catch concurrency issues early
4. **Use testify assertions** - Better error messages and cleaner code
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
6. **Test edge cases systematically** - Use table-driven tests
-163
View File
@@ -1,163 +0,0 @@
# Google OAuth Integration Fix
## Problem Overview
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
```
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
```
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
## Technical Details of the Issue
### Standard OIDC Provider Behavior
Most OpenID Connect (OIDC) providers follow the standard specification, where:
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
- This allows authenticated sessions to persist beyond the initial access token expiration
### Google's Non-Standard Approach
Google's OAuth implementation deviates from the standard by:
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
## Solution Implementation
The fix involved modifying the authentication flow to specifically handle Google providers:
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
```go
// Check if we're dealing with a Google OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
strings.Contains(t.issuerURL, "accounts.google.com")
```
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
```go
// Handle offline access differently for Google vs other providers
if isGoogleProvider {
// For Google, use access_type=offline parameter instead of offline_access scope
params.Set("access_type", "offline")
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
// Add prompt=consent for Google to ensure refresh token is issued
params.Set("prompt", "consent")
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else {
// For non-Google providers, use the offline_access scope
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
}
}
```
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
## Why This Approach Works
This solution aligns with Google's OAuth 2.0 documentation which specifies:
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
## Testing and Verification
Comprehensive tests were implemented to verify the solution:
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
2. **Auth URL Parameter Tests**: Verifies that:
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
Sample test case (simplified):
```go
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that access_type=offline was added (not offline_access scope for Google)
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
}
// Verify offline_access scope is NOT included for Google providers
if strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
```
## Usage Guidance for Developers
When configuring the Traefik OIDC middleware for Google:
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
- Configure the authorized redirect URI to match your `callbackURL` setting
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
3. **Configuration Example**:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-google-client-id.apps.googleusercontent.com
clientSecret: your-google-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
scopes:
- openid
- email
- profile
# Note: DO NOT manually add offline_access scope for Google
# The middleware handles this automatically and correctly
```
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
- Review your application logs with `logLevel: debug` to check for refresh token errors
- Verify you're using a version of the middleware that includes this fix
## Conclusion
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
+1373
View File
File diff suppressed because it is too large Load Diff
+540
View File
@@ -0,0 +1,540 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
)
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// ClientRegistrationError represents an error response from client registration (RFC 7591)
type ClientRegistrationError struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
}
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
type DynamicClientRegistrar struct {
httpClient *http.Client
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
providerURL string
mu sync.RWMutex
}
// NewDynamicClientRegistrar creates a new dynamic client registrar
func NewDynamicClientRegistrar(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
}
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from a file if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
return nil, fmt.Errorf("dynamic client registration is not enabled")
}
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
if resp, err := r.loadCredentials(); err == nil && resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from file")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
return resp, nil
}
r.logger.Info("Existing credentials expired or invalid, registering new client")
}
}
// Determine registration endpoint
endpoint := registrationEndpoint
if r.config.RegistrationEndpoint != "" {
endpoint = r.config.RegistrationEndpoint
}
if endpoint == "" {
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
}
// Validate the endpoint URL
if !strings.HasPrefix(endpoint, "https://") {
// Allow http only for localhost/development
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
}
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
}
// Build registration request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build registration request: %w", err)
}
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
// Add Initial Access Token if provided
if r.config.InitialAccessToken != "" {
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
}
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
if err != nil {
return nil, fmt.Errorf("failed to read registration response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse registration response: %w", err)
}
// Validate response
if regResp.ClientID == "" {
return nil, fmt.Errorf("registration response missing client_id")
}
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
// Cache the response
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
}
return &regResp, nil
}
// buildRegistrationRequest creates the JSON request body for client registration
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
metadata := r.config.ClientMetadata
if metadata == nil {
metadata = &ClientRegistrationMetadata{}
}
// Build request object
reqData := make(map[string]interface{})
// Required: redirect_uris
if len(metadata.RedirectURIs) > 0 {
reqData["redirect_uris"] = metadata.RedirectURIs
} else {
return nil, fmt.Errorf("redirect_uris is required for client registration")
}
// Optional fields - only include if set
if len(metadata.ResponseTypes) > 0 {
reqData["response_types"] = metadata.ResponseTypes
} else {
// Default to authorization code flow
reqData["response_types"] = []string{"code"}
}
if len(metadata.GrantTypes) > 0 {
reqData["grant_types"] = metadata.GrantTypes
} else {
// Default grant types for authorization code flow
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
}
if metadata.ApplicationType != "" {
reqData["application_type"] = metadata.ApplicationType
}
if len(metadata.Contacts) > 0 {
reqData["contacts"] = metadata.Contacts
}
if metadata.ClientName != "" {
reqData["client_name"] = metadata.ClientName
}
if metadata.LogoURI != "" {
reqData["logo_uri"] = metadata.LogoURI
}
if metadata.ClientURI != "" {
reqData["client_uri"] = metadata.ClientURI
}
if metadata.PolicyURI != "" {
reqData["policy_uri"] = metadata.PolicyURI
}
if metadata.TOSURI != "" {
reqData["tos_uri"] = metadata.TOSURI
}
if metadata.JWKSURI != "" {
reqData["jwks_uri"] = metadata.JWKSURI
}
if metadata.SubjectType != "" {
reqData["subject_type"] = metadata.SubjectType
}
if metadata.TokenEndpointAuthMethod != "" {
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
} else {
// Default to client_secret_basic for confidential clients
reqData["token_endpoint_auth_method"] = "client_secret_basic"
}
if metadata.DefaultMaxAge > 0 {
reqData["default_max_age"] = metadata.DefaultMaxAge
}
if metadata.RequireAuthTime {
reqData["require_auth_time"] = metadata.RequireAuthTime
}
if len(metadata.DefaultACRValues) > 0 {
reqData["default_acr_values"] = metadata.DefaultACRValues
}
if metadata.Scope != "" {
reqData["scope"] = metadata.Scope
}
return json.Marshal(reqData)
}
// GetCachedResponse returns the cached registration response
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
r.mu.RLock()
defer r.mu.RUnlock()
return r.registrationResponse
}
// areCredentialsValid checks if the cached credentials are still valid
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
if resp == nil || resp.ClientID == "" {
return false
}
// Check if secret has expired
if resp.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
// Add 5 minute buffer before expiration
if time.Now().Add(5 * time.Minute).After(expiresAt) {
return false
}
}
return true
}
// credentialsFilePath returns the path for storing credentials
func (r *DynamicClientRegistrar) credentialsFilePath() string {
if r.config.CredentialsFile != "" {
return r.config.CredentialsFile
}
return "/tmp/oidc-client-credentials.json"
}
// saveCredentials persists client credentials to a file
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
data, err := json.MarshalIndent(resp, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
r.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// loadCredentials loads client credentials from a file
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var resp ClientRegistrationResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials file if persistence is enabled
if r.config.PersistCredentials {
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
r.logger.Errorf("Failed to remove credentials file: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
File diff suppressed because it is too large Load Diff
+620
View File
@@ -0,0 +1,620 @@
package traefikoidc
import (
"context"
"encoding/base64"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
"github.com/stretchr/testify/suite"
"golang.org/x/time/rate"
)
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
type ClockSkewEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
// Create JWK for the test key
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error") // Reduce noise
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
token, err := s.fixture.TokenWithSkew(0)
s.Require().NoError(err)
// Token at exact expiry - behavior is implementation-defined
err = s.tOidc.VerifyToken(token)
s.T().Logf("Exact expiry result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
token, err := s.fixture.TokenWithSkew(1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token should be valid 1 second before expiry")
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
}
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
// Most implementations allow 5-minute clock skew
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// May pass or fail depending on implementation
s.T().Logf("4-minute expired token result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.Error(err, "Token should be invalid 10 minutes after expiry")
}
func TestClockSkewEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ClockSkewEdgeCasesSuite))
}
// UnicodeClaimsSuite tests Unicode handling in JWT claims
type UnicodeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *UnicodeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *UnicodeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
token, err := s.fixture.TokenWithEmail("用户@example.com")
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode email should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestUnicodeName() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "田中太郎",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode name should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test User 😀",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Emoji in claims should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestRTLText() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "مستخدم اختبار",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "RTL text should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestMixedScripts() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test 测试 テスト",
"roles": []string{"admin", "管理者", "管理员"},
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Mixed scripts should be handled correctly")
}
func TestUnicodeClaimsSuite(t *testing.T) {
suite.Run(t, new(UnicodeClaimsSuite))
}
// LargeClaimsSuite tests large claim values
type LargeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *LargeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *LargeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *LargeClaimsSuite) TestManyRoles() {
roles := make([]string, 100)
for i := 0; i < 100; i++ {
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithRoles(roles)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 100 roles should be handled")
}
func (s *LargeClaimsSuite) TestManyGroups() {
groups := make([]string, 50)
for i := 0; i < 50; i++ {
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithGroups(groups)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 50 groups should be handled")
}
func (s *LargeClaimsSuite) TestLongEmail() {
// RFC 5321 allows up to 254 characters
localPart := strings.Repeat("a", 64)
domain := strings.Repeat("b", 63) + ".com"
email := localPart + "@" + domain
token, err := s.fixture.TokenWithEmail(email)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long email should be handled")
}
func (s *LargeClaimsSuite) TestLongSubject() {
longSub := strings.Repeat("subject", 100)
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"sub": longSub,
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long subject should be handled")
}
func TestLargeClaimsSuite(t *testing.T) {
suite.Run(t, new(LargeClaimsSuite))
}
// URLPathEdgeCasesSuite tests URL handling edge cases
type URLPathEdgeCasesSuite struct {
suite.Suite
}
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
longPath := "/" + strings.Repeat("segment/", 100)
req := httptest.NewRequest("GET", longPath, nil)
s.NotNil(req)
s.Contains(req.URL.Path, "segment")
}
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
paths := []string{
"/path%20with%20spaces",
"/path/with/日本語",
"/path?query=value&another=test",
"/path#fragment",
"/path/../traversal",
"/path/./current",
}
for _, path := range paths {
s.Run(path, func() {
req := httptest.NewRequest("GET", path, nil)
s.NotNil(req)
})
}
}
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
req := httptest.NewRequest("GET", "/", nil)
s.Equal("/", req.URL.Path)
}
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
req := httptest.NewRequest("GET", "//double//slashes//", nil)
s.NotNil(req)
}
func TestURLPathEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(URLPathEdgeCasesSuite))
}
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
type ConcurrencyEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
const goroutines = 50
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}()
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
const goroutines = 20
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"custom": idx,
})
if err != nil {
errors <- err
return
}
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent different token error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent different token validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
validToken, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
expiredToken, err := s.fixture.ExpiredToken()
s.Require().NoError(err)
const goroutines = 40
var wg sync.WaitGroup
validCount := int32(0)
expiredCount := int32(0)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
var token string
if idx%2 == 0 {
token = validToken
} else {
token = expiredToken
}
err := s.tOidc.VerifyToken(token)
if idx%2 == 0 {
if err == nil {
atomic.AddInt32(&validCount, 1)
}
} else {
if err != nil {
atomic.AddInt32(&expiredCount, 1)
}
}
}(i)
}
wg.Wait()
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
}
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
}
+258
View File
@@ -0,0 +1,258 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
type EnhancedMocksSuite struct {
suite.Suite
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Make some calls
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.NoError(err)
s.NotNil(result)
// Another call with different URL
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
// Verify calls were tracked
s.Equal(2, mock.GetJWKSCallCount())
mock.AssertGetJWKSCalled(s.T())
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
mock.AssertGetJWKSCallCount(s.T(), 2)
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
expectedErr := errors.New("network error")
mock := &EnhancedMockJWKCache{
Err: expectedErr,
}
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Nil(result)
s.Equal(expectedErr, err)
mock.AssertGetJWKSCalled(s.T())
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Equal(1, mock.GetJWKSCallCount())
mock.Reset()
s.Equal(0, mock.GetJWKSCallCount())
s.Nil(mock.JWKS)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
mock := &EnhancedMockTokenVerifier{
Err: nil, // Valid tokens
}
// Verify a token
err := mock.VerifyToken("test-token-1")
s.NoError(err)
// Verify another token
err = mock.VerifyToken("test-token-2")
s.NoError(err)
// Check tracking
s.Equal(2, mock.GetVerifyTokenCallCount())
mock.AssertVerifyTokenCalled(s.T())
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
// Check last call
lastCall := mock.LastCall()
s.NotNil(lastCall)
s.Equal("test-token-2", lastCall.Token)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
callCount := 0
mock := &EnhancedMockTokenVerifier{
VerifyFunc: func(token string) error {
callCount++
if token == "invalid" {
return errors.New("invalid token")
}
return nil
},
}
// Valid token
err := mock.VerifyToken("valid-token")
s.NoError(err)
// Invalid token
err = mock.VerifyToken("invalid")
s.Error(err)
s.Equal(2, callCount)
s.Equal(2, mock.GetVerifyTokenCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
mock := &EnhancedMockTokenExchanger{
ExchangeResponse: &TokenResponse{
AccessToken: "access-token",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
},
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
},
}
// Exchange code
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
s.NoError(err)
s.Equal("access-token", resp.AccessToken)
// Refresh token
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
s.NoError(err)
s.Equal("new-access-token", resp.AccessToken)
// Revoke token
err = mock.RevokeTokenWithProvider("access-token", "access_token")
s.NoError(err)
// Check tracking
mock.AssertExchangeCalled(s.T())
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
mock.AssertRefreshCalled(s.T())
mock.AssertRevokeCalled(s.T())
s.Equal(1, mock.GetExchangeCallCount())
s.Equal(1, mock.GetRefreshCallCount())
s.Equal(1, mock.GetRevokeCallCount())
// Check last exchange call details
lastExchange := mock.LastExchangeCall()
s.NotNil(lastExchange)
s.Equal("authorization_code", lastExchange.GrantType)
s.Equal("auth-code", lastExchange.CodeOrToken)
s.Equal("https://redirect.com", lastExchange.RedirectURL)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
mock := &EnhancedMockTokenExchanger{
ExchangeErr: errors.New("invalid_grant"),
RefreshErr: errors.New("refresh_expired"),
RevokeErr: errors.New("revoke_failed"),
}
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
s.Error(err)
s.Contains(err.Error(), "invalid_grant")
_, err = mock.GetNewTokenWithRefreshToken("token")
s.Error(err)
s.Contains(err.Error(), "refresh_expired")
err = mock.RevokeTokenWithProvider("token", "access_token")
s.Error(err)
s.Contains(err.Error(), "revoke_failed")
}
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
mock := NewEnhancedMockCache()
// Set some values
mock.Set("key1", "value1", 5*time.Minute)
mock.Set("key2", "value2", 10*time.Minute)
// Get values
val, found := mock.Get("key1")
s.True(found)
s.Equal("value1", val)
_, found = mock.Get("nonexistent")
s.False(found)
// Delete
mock.Delete("key1")
// Verify tracking
mock.AssertSetCalled(s.T(), "key1")
mock.AssertSetCalled(s.T(), "key2")
mock.AssertGetCalled(s.T(), "key1")
mock.AssertGetCalled(s.T(), "nonexistent")
mock.AssertDeleteCalled(s.T(), "key1")
s.Equal(2, mock.SetCallCount())
s.Equal(2, mock.GetCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
mock := NewEnhancedMockCache()
// The enhanced mock actually stores data
mock.Set("key", "value", time.Hour)
s.Equal(1, mock.Size())
val, found := mock.Get("key")
s.True(found)
s.Equal("value", val)
mock.Delete("key")
s.Equal(0, mock.Size())
_, found = mock.Get("key")
s.False(found)
}
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
mock := NewEnhancedMockCache()
mock.Set("key1", "value1", time.Hour)
mock.Set("key2", "value2", time.Hour)
s.Equal(2, mock.Size())
mock.Clear()
s.Equal(0, mock.Size())
}
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Concurrent calls should be safe
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
s.Equal(10, mock.GetJWKSCallCount())
}
func TestEnhancedMocksSuite(t *testing.T) {
suite.Run(t, new(EnhancedMocksSuite))
}
+577
View File
@@ -0,0 +1,577 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/stretchr/testify/assert"
)
// EnhancedMockJWKCache is an improved state-based mock with call tracking
type EnhancedMockJWKCache struct {
Err error
JWKS *JWKSet
GetJWKSCalls []JWKSCall
mu sync.RWMutex
getJWKSCallsMu sync.Mutex
CleanupCalls int32
CloseCalls int32
}
// JWKSCall records parameters from a GetJWKS call
type JWKSCall struct {
Timestamp time.Time
URL string
}
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
URL: jwksURL,
Timestamp: time.Now(),
})
m.getJWKSCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
defer m.mu.Unlock()
m.JWKS = nil
m.Err = nil
}
func (m *EnhancedMockJWKCache) Close() {
atomic.AddInt32(&m.CloseCalls, 1)
}
// Assertion helpers
// AssertGetJWKSCalled verifies GetJWKS was called
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
}
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
for _, call := range m.GetJWKSCalls {
if call.URL == expectedURL {
return true
}
}
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
}
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
}
// GetJWKSCallCount returns the number of GetJWKS calls
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return len(m.GetJWKSCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockJWKCache) Reset() {
m.mu.Lock()
m.JWKS = nil
m.Err = nil
m.mu.Unlock()
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = nil
m.getJWKSCallsMu.Unlock()
atomic.StoreInt32(&m.CleanupCalls, 0)
atomic.StoreInt32(&m.CloseCalls, 0)
}
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
type EnhancedMockTokenVerifier struct {
Err error
VerifyFunc func(token string) error
VerifyCalls []TokenVerifyCall
mu sync.RWMutex
verifyCallsMu sync.Mutex
}
// TokenVerifyCall records parameters from a VerifyToken call
type TokenVerifyCall struct {
Timestamp time.Time
Result error
Token string
}
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
var result error
m.mu.RLock()
if m.VerifyFunc != nil {
result = m.VerifyFunc(token)
} else {
result = m.Err
}
m.mu.RUnlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
Token: token,
Timestamp: time.Now(),
Result: result,
})
m.verifyCallsMu.Unlock()
return result
}
// Assertion helpers
// AssertVerifyTokenCalled verifies VerifyToken was called
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
}
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
for _, call := range m.VerifyCalls {
if call.Token == expectedToken {
return true
}
}
return assert.Fail(t, "VerifyToken was not called with expected token")
}
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
}
// GetVerifyTokenCallCount returns the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return len(m.VerifyCalls)
}
// LastCall returns the most recent VerifyToken call
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
if len(m.VerifyCalls) == 0 {
return nil
}
return &m.VerifyCalls[len(m.VerifyCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenVerifier) Reset() {
m.mu.Lock()
m.Err = nil
m.VerifyFunc = nil
m.mu.Unlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = nil
m.verifyCallsMu.Unlock()
}
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
type EnhancedMockTokenExchanger struct {
RefreshErr error
RevokeErr error
ExchangeErr error
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshResponse *TokenResponse
ExchangeResponse *TokenResponse
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
mu sync.RWMutex
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
}
// ExchangeCall records parameters from an ExchangeCodeForToken call
type ExchangeCall struct {
Timestamp time.Time
GrantType string
CodeOrToken string
RedirectURL string
CodeVerifier string
}
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
type RefreshCall struct {
Timestamp time.Time
RefreshToken string
}
// RevokeCall records parameters from a RevokeTokenWithProvider call
type RevokeCall struct {
Timestamp time.Time
Token string
TokenType string
}
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
m.exchangeCallsMu.Lock()
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
GrantType: grantType,
CodeOrToken: codeOrToken,
RedirectURL: redirectURL,
CodeVerifier: codeVerifier,
Timestamp: time.Now(),
})
m.exchangeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return m.ExchangeResponse, m.ExchangeErr
}
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
m.refreshCallsMu.Lock()
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
RefreshToken: refreshToken,
Timestamp: time.Now(),
})
m.refreshCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return m.RefreshResponse, m.RefreshErr
}
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
m.revokeCallsMu.Lock()
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
Token: token,
TokenType: tokenType,
Timestamp: time.Now(),
})
m.revokeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return m.RevokeErr
}
// Assertion helpers
// AssertExchangeCalled verifies ExchangeCodeForToken was called
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
}
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
for _, call := range m.ExchangeCalls {
if call.GrantType == grantType {
return true
}
}
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
}
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
}
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
}
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return len(m.ExchangeCalls)
}
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return len(m.RefreshCalls)
}
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return len(m.RevokeCalls)
}
// LastExchangeCall returns the most recent ExchangeCodeForToken call
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
if len(m.ExchangeCalls) == 0 {
return nil
}
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenExchanger) Reset() {
m.mu.Lock()
m.ExchangeResponse = nil
m.ExchangeErr = nil
m.RefreshResponse = nil
m.RefreshErr = nil
m.RevokeErr = nil
m.ExchangeCodeFunc = nil
m.RefreshTokenFunc = nil
m.RevokeTokenFunc = nil
m.mu.Unlock()
m.exchangeCallsMu.Lock()
m.ExchangeCalls = nil
m.exchangeCallsMu.Unlock()
m.refreshCallsMu.Lock()
m.RefreshCalls = nil
m.refreshCallsMu.Unlock()
m.revokeCallsMu.Lock()
m.RevokeCalls = nil
m.revokeCallsMu.Unlock()
}
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
type EnhancedMockCacheInterface struct {
data map[string]cacheEntry
GetCalls []CacheGetCall
SetCalls []CacheSetCall
DeleteCalls []string
maxSize int
mu sync.RWMutex
getCalls sync.Mutex
setCalls sync.Mutex
deleteCalls sync.Mutex
}
type cacheEntry struct {
value any
ttl time.Duration
}
// CacheGetCall records parameters from a Get call
type CacheGetCall struct {
Timestamp time.Time
Key string
Found bool
}
// CacheSetCall records parameters from a Set call
type CacheSetCall struct {
Timestamp time.Time
Value any
Key string
TTL time.Duration
}
// NewEnhancedMockCache creates a new enhanced cache mock
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
return &EnhancedMockCacheInterface{
data: make(map[string]cacheEntry),
maxSize: 1000,
}
}
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
m.setCalls.Lock()
m.SetCalls = append(m.SetCalls, CacheSetCall{
Key: key,
Value: value,
TTL: ttl,
Timestamp: time.Now(),
})
m.setCalls.Unlock()
m.mu.Lock()
m.data[key] = cacheEntry{value: value, ttl: ttl}
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
m.mu.RLock()
entry, found := m.data[key]
m.mu.RUnlock()
m.getCalls.Lock()
m.GetCalls = append(m.GetCalls, CacheGetCall{
Key: key,
Found: found,
Timestamp: time.Now(),
})
m.getCalls.Unlock()
if found {
return entry.value, true
}
return nil, false
}
func (m *EnhancedMockCacheInterface) Delete(key string) {
m.deleteCalls.Lock()
m.DeleteCalls = append(m.DeleteCalls, key)
m.deleteCalls.Unlock()
m.mu.Lock()
delete(m.data, key)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
m.mu.Lock()
m.maxSize = size
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Size() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.data)
}
func (m *EnhancedMockCacheInterface) Clear() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Cleanup() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) Close() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
m.mu.RLock()
defer m.mu.RUnlock()
return map[string]any{
"size": len(m.data),
"max_size": m.maxSize,
}
}
// Assertion helpers
// AssertGetCalled verifies Get was called with specific key
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
m.getCalls.Lock()
defer m.getCalls.Unlock()
for _, call := range m.GetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Get was not called with key: "+key)
}
// AssertSetCalled verifies Set was called with specific key
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
m.setCalls.Lock()
defer m.setCalls.Unlock()
for _, call := range m.SetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Set was not called with key: "+key)
}
// AssertDeleteCalled verifies Delete was called with specific key
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
m.deleteCalls.Lock()
defer m.deleteCalls.Unlock()
for _, k := range m.DeleteCalls {
if k == key {
return true
}
}
return assert.Fail(t, "Delete was not called with key: "+key)
}
// GetCallCount returns the number of Get calls
func (m *EnhancedMockCacheInterface) GetCallCount() int {
m.getCalls.Lock()
defer m.getCalls.Unlock()
return len(m.GetCalls)
}
// SetCallCount returns the number of Set calls
func (m *EnhancedMockCacheInterface) SetCallCount() int {
m.setCalls.Lock()
defer m.setCalls.Unlock()
return len(m.SetCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockCacheInterface) Reset() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
m.getCalls.Lock()
m.GetCalls = nil
m.getCalls.Unlock()
m.setCalls.Lock()
m.SetCalls = nil
m.setCalls.Unlock()
m.deleteCalls.Lock()
m.DeleteCalls = nil
m.deleteCalls.Unlock()
}
+150 -38
View File
@@ -2,10 +2,14 @@ package traefikoidc
import (
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"math"
"math/rand/v2"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -123,8 +127,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
}
if metrics["total_requests"].(int64) > 0 {
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
if totalReq > 0 {
successRate := float64(totalSucc) / float64(totalReq)
metrics["success_rate"] = successRate
} else {
metrics["success_rate"] = 1.0
@@ -409,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
// fetching during startup. Uses more aggressive retry settings to handle the race
// condition where Traefik initializes the plugin before routes are fully established,
// or before TLS certificates are properly loaded.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func MetadataFetchRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 10, // More attempts for startup scenarios
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
MaxDelay: 10 * time.Second, // Cap at 10 seconds
BackoffFactor: 1.5, // Gentler backoff for startup
EnableJitter: true, // Prevent thundering herd
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
}
// RetryExecutor implements retry logic with exponential backoff and jitter.
// It automatically retries failed operations based on configurable error patterns
// and uses exponential backoff to avoid overwhelming failing services.
@@ -485,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// isRetryableError checks if an error should trigger a retry
// isRetryableError determines if an error should trigger a retry attempt.
// Checks error message against configured retryable error patterns.
// Also handles startup-specific errors like Traefik default certificate errors
// and EOF errors that occur during service initialization.
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for Traefik default certificate error (startup race condition)
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
if isTraefikDefaultCertError(err) {
return true
}
// Check for EOF errors (common during startup when services aren't ready)
if isEOFError(err) {
return true
}
// Check for certificate errors (transient during startup)
if isCertificateError(err) {
return true
}
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
@@ -536,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
delay = float64(re.config.MaxDelay)
}
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.EnableJitter {
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
delay += jitter
@@ -592,14 +642,10 @@ func (e *HTTPError) Error() string {
// OIDCError represents OIDC-specific errors with context information.
// It provides structured error reporting for authentication and authorization failures.
type OIDCError struct {
// Code identifies the specific error type
Code string
// Message provides a human-readable description
Message string
// Context contains additional error context (e.g., provider, session details)
Cause error
Context map[string]interface{}
// Cause is the underlying error that caused this error
Cause error
Code string
Message string
}
// Error returns the string representation of the OIDC error.
@@ -619,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
// SessionError represents session-related errors with context.
// Used for session management, validation, and storage errors.
type SessionError struct {
// Operation describes what session operation failed
Cause error
Operation string
// Message provides a human-readable description
Message string
// SessionID identifies the session (if available)
Message string
SessionID string
// Cause is the underlying error that caused this error
Cause error
}
// Error returns the string representation of the session error.
@@ -646,14 +688,10 @@ func (e *SessionError) Unwrap() error {
// TokenError represents token-related errors with validation context.
// Used for JWT validation, token refresh, and token format errors.
type TokenError struct {
// TokenType identifies the type of token (id_token, access_token, refresh_token)
Cause error
TokenType string
// Reason describes why the token is invalid
Reason string
// Message provides a human-readable description
Message string
// Cause is the underlying error that caused this error
Cause error
Reason string
Message string
}
// Error returns the string representation of the token error.
@@ -715,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
// It provides fallback mechanisms when primary services are unavailable and monitors
// service health to automatically recover when services become available again.
type GracefulDegradation struct {
// BaseRecoveryMechanism provides common functionality
*BaseRecoveryMechanism
// fallbacks stores service-specific fallback implementations
fallbacks map[string]func() (interface{}, error)
// healthChecks stores service health check functions
healthChecks map[string]func() bool
// degradedServices tracks which services are currently degraded
fallbacks map[string]func() (interface{}, error)
healthChecks map[string]func() bool
degradedServices map[string]time.Time
// config contains graceful degradation configuration
config GracefulDegradationConfig
// mutex protects shared state
mutex sync.RWMutex
// healthCheckTask manages background health checking
healthCheckTask *BackgroundTask
// stopChan signals shutdown
stopChan chan struct{}
// shutdownOnce ensures shutdown happens only once
shutdownOnce sync.Once
healthCheckTask *BackgroundTask
stopChan chan struct{}
config GracefulDegradationConfig
mutex sync.RWMutex
shutdownOnce sync.Once
}
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
@@ -1085,3 +1114,86 @@ func containsSubstring(s, substr string) bool {
}
return false
}
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
// certificate during cold-start, before the real certificates are loaded.
// This manifests as an x509.HostnameError where one of the certificate's DNS names
// ends with "traefik.default" (the default Traefik certificate pattern).
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func isTraefikDefaultCertError(err error) bool {
if err == nil {
return false
}
var hostnameErr x509.HostnameError
if errors.As(err, &hostnameErr) {
if hostnameErr.Certificate != nil {
for _, name := range hostnameErr.Certificate.DNSNames {
if strings.HasSuffix(name, "traefik.default") {
return true
}
}
}
}
return false
}
// isEOFError checks if an error is an EOF error, which can occur during
// connection establishment when the remote end closes unexpectedly.
// This is common during service startup when endpoints aren't fully ready.
func isEOFError(err error) bool {
if err == nil {
return false
}
// Check for direct EOF
if errors.Is(err, io.EOF) {
return true
}
// Check for unexpected EOF
if errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check error message for EOF patterns (wrapped errors)
errStr := err.Error()
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
}
// isCertificateError checks if an error is related to TLS certificate validation.
// These errors are often transient during startup when services are still initializing.
func isCertificateError(err error) bool {
if err == nil {
return false
}
// Check for x509 certificate errors
var certInvalidErr x509.CertificateInvalidError
var hostnameErr x509.HostnameError
var unknownAuthErr x509.UnknownAuthorityError
if errors.As(err, &certInvalidErr) ||
errors.As(err, &hostnameErr) ||
errors.As(err, &unknownAuthErr) {
return true
}
// Check error message for certificate patterns
errStr := strings.ToLower(err.Error())
certPatterns := []string{
"certificate",
"x509",
"tls",
"ssl",
}
for _, pattern := range certPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
-242
View File
@@ -1,242 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
// TestDefaultCircuitBreakerConfig tests the default configuration function
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
// Test default values
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
}
}
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
metrics := base.GetBaseMetrics()
if metrics == nil {
t.Fatal("Expected non-nil metrics")
}
// Check expected metric fields
expectedFields := []string{
"total_requests",
"total_failures",
"total_successes",
"uptime_seconds",
"name",
}
for _, field := range expectedFields {
if _, exists := metrics[field]; !exists {
t.Errorf("Expected metric field %s to exist", field)
}
}
}
// TestBaseRecoveryMechanism_RecordRequest tests request recording
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some requests
base.RecordRequest()
base.RecordRequest()
base.RecordRequest()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalRequests := metrics["total_requests"].(int64)
if totalRequests != 3 {
t.Errorf("Expected 3 total requests, got %d", totalRequests)
}
}
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some successes
base.RecordSuccess()
base.RecordSuccess()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalSuccesses := metrics["total_successes"].(int64)
if totalSuccesses != 2 {
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
}
}
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Record some failures
base.RecordFailure()
base.RecordFailure()
base.RecordFailure()
// Get metrics to verify
metrics := base.GetBaseMetrics()
totalFailures := metrics["total_failures"].(int64)
if totalFailures != 3 {
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
}
}
// TestBaseRecoveryMechanism_LogInfo tests info logging
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogInfo("test message")
base.LogInfo("test message with args: %s %d", "arg1", 42)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogInfo("test message") // Should not panic
}
// TestBaseRecoveryMechanism_LogError tests error logging
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogError("error message")
base.LogError("error message with args: %s %d", "error", 500)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogError("error message") // Should not panic
}
// TestBaseRecoveryMechanism_LogDebug tests debug logging
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
// Test logging doesn't panic
base.LogDebug("debug message")
base.LogDebug("debug message with args: %s %d", "debug", 123)
// Test with nil logger
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
baseNoLogger.LogDebug("debug message") // Should not panic
}
// TestCircuitBreaker_GetState tests getting circuit breaker state
func TestCircuitBreaker_GetState(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Initial state should be closed
state := cb.GetState()
if state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be closed, got %d", state)
}
}
// TestCircuitBreaker_Reset tests resetting circuit breaker
func TestCircuitBreaker_Reset(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Reset should not panic
cb.Reset()
// State should be closed after reset
state := cb.GetState()
if state != CircuitBreakerClosed {
t.Errorf("Expected state to be closed after reset, got %d", state)
}
}
// TestCircuitBreaker_IsAvailable tests availability check
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
// Initially should be available
available := cb.IsAvailable()
if !available {
t.Error("Expected circuit breaker to be available initially")
}
}
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := GetSingletonNoOpLogger()
cb := NewCircuitBreaker(config, logger)
metrics := cb.GetMetrics()
if metrics == nil {
t.Fatal("Expected non-nil metrics")
}
// Should include base metrics
if _, exists := metrics["total_requests"]; !exists {
t.Error("Expected total_requests in metrics")
}
// Should include circuit breaker specific metrics
if _, exists := metrics["state"]; !exists {
t.Error("Expected state in metrics")
}
}
// Retry mechanism tests removed due to complex dependencies
// Benchmark tests
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import "testing"
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
File diff suppressed because it is too large Load Diff
+486
View File
@@ -0,0 +1,486 @@
# ============================================================================
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
# ============================================================================
#
# This example shows a complete, production-ready configuration for using
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
#
# ============================================================================
# Part 1: Traefik Static Configuration (traefik.yml)
# ============================================================================
# This file configures Traefik itself and enables the plugin.
# Place this in /etc/traefik/traefik.yml or mount it in your container.
---
# Static Configuration
api:
dashboard: true
insecure: false # Set to true only for local development
entryPoints:
web:
address: ":80"
http:
redirections:
entryPoint:
to: websecure
scheme: https
websecure:
address: ":443"
http:
tls:
certResolver: letsencrypt
certificatesResolvers:
letsencrypt:
acme:
email: admin@example.com
storage: /letsencrypt/acme.json
httpChallenge:
entryPoint: web
providers:
file:
filename: /etc/traefik/dynamic.yml
watch: true
# Enable the TraefikOIDC plugin
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
log:
level: INFO
format: json
accessLog:
format: json
# ============================================================================
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
# ============================================================================
# This file defines your routes, services, and middleware.
# Place this in /etc/traefik/dynamic.yml
---
http:
# -------------------------------------------------------------------------
# Middleware Definitions
# -------------------------------------------------------------------------
middlewares:
# Example 1: Minimal Redis Configuration
# Perfect for getting started quickly
oidc-minimal:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-application-client-id"
clientSecret: "your-client-secret-from-provider"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
# Minimal Redis configuration
redis:
enabled: true
address: "redis:6379"
# Example 2: Production Redis Configuration
# Recommended for production deployments with multiple Traefik replicas
oidc-production:
plugin:
traefikoidc:
# OIDC Provider Configuration
clientID: "prod-client-id"
clientSecret: "prod-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
# Security Settings
forceHTTPS: true
strictAudienceValidation: true
# Redis Configuration for Multi-Replica Deployment
redis:
enabled: true
address: "redis-master.redis-namespace.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:prod:"
# Cache Strategy
cacheMode: "hybrid" # Fast local cache + shared Redis
# Connection Pooling
poolSize: 20
connectTimeout: 5
readTimeout: 3
writeTimeout: 3
# Resilience Features
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS (for production security)
oidc-secure:
plugin:
traefikoidc:
clientID: "secure-client-id"
clientSecret: "secure-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
enableTLS: true
tlsSkipVerify: false # Verify certificates in production
cacheMode: "redis"
# Example 4: Hybrid Mode (Best Performance + Consistency)
# Local cache for hot data, Redis for consistency across replicas
oidc-hybrid:
plugin:
traefikoidc:
clientID: "app-client-id"
clientSecret: "app-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
cacheMode: "hybrid"
# Hybrid mode L1 cache settings
hybridL1Size: 1000 # Number of items in local cache
hybridL1MemoryMB: 20 # MB of memory for local cache
# -------------------------------------------------------------------------
# Router Definitions
# -------------------------------------------------------------------------
routers:
# Protected application using OIDC authentication
my-app:
rule: "Host(`app.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-production # Use the OIDC middleware
service: my-app-service
tls:
certResolver: letsencrypt
# Another app with minimal OIDC config
simple-app:
rule: "Host(`simple.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-minimal
service: simple-app-service
tls:
certResolver: letsencrypt
# -------------------------------------------------------------------------
# Service Definitions
# -------------------------------------------------------------------------
services:
my-app-service:
loadBalancer:
servers:
- url: "http://my-app:8080"
healthCheck:
path: /health
interval: 30s
timeout: 5s
simple-app-service:
loadBalancer:
servers:
- url: "http://simple-app:3000"
# ============================================================================
# Part 3: Docker Compose Example
# ============================================================================
---
# docker-compose.yml
version: '3.8'
services:
# Redis service for shared caching
redis:
image: redis:7-alpine
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
ports:
- "6379:6379"
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 10s
timeout: 3s
retries: 5
networks:
- traefik-network
# Traefik with TraefikOIDC plugin
traefik:
image: traefik:v3.2
command:
- "--api.dashboard=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--providers.file.filename=/etc/traefik/dynamic.yml"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.8.0"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
- ./letsencrypt:/letsencrypt
depends_on:
- redis
networks:
- traefik-network
# Your application
my-app:
image: my-app:latest
labels:
- "traefik.enable=true"
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
- "traefik.http.routers.my-app.entrypoints=websecure"
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
# OIDC Middleware Configuration with Redis (using labels)
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
# Redis configuration
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
networks:
- traefik-network
deploy:
replicas: 3 # Multiple replicas sharing Redis cache
volumes:
redis-data:
networks:
traefik-network:
driver: bridge
# ============================================================================
# Part 4: Kubernetes Example
# ============================================================================
---
# kubernetes-example.yaml
# Redis Deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: redis
namespace: traefik
spec:
replicas: 1
selector:
matchLabels:
app: redis
template:
metadata:
labels:
app: redis
spec:
containers:
- name: redis
image: redis:7-alpine
args:
- redis-server
- --requirepass
- $(REDIS_PASSWORD)
- --maxmemory
- 512mb
- --maxmemory-policy
- allkeys-lru
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secret
key: password
ports:
- containerPort: 6379
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
---
# Redis Service
apiVersion: v1
kind: Service
metadata:
name: redis
namespace: traefik
spec:
selector:
app: redis
ports:
- port: 6379
targetPort: 6379
---
# Redis Secret
apiVersion: v1
kind: Secret
metadata:
name: redis-secret
namespace: traefik
type: Opaque
stringData:
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
---
# OIDC Middleware with Redis
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
namespace: traefik
spec:
plugin:
traefikoidc:
# OIDC Configuration
clientID: "kubernetes-client-id"
clientSecret: "kubernetes-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
# Redis Configuration
redis:
enabled: true
address: "redis.traefik.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:k8s:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
enableHealthCheck: true
---
# IngressRoute using the middleware
apiVersion: traefik.io/v1alpha1
kind: IngressRoute
metadata:
name: my-app
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`app.example.com`)
kind: Rule
middlewares:
- name: oidc-auth
namespace: traefik
services:
- name: my-app
port: 80
tls:
certResolver: letsencrypt
# ============================================================================
# Part 5: Environment Variables (Optional Fallback)
# ============================================================================
# If you prefer environment variables as fallback (not recommended for production),
# you can set these. NOTE: Plugin configuration takes precedence!
# Docker Compose env file (.env)
---
# OIDC Configuration
OIDC_CLIENT_ID=your-client-id
OIDC_CLIENT_SECRET=your-client-secret
OIDC_PROVIDER_URL=https://auth.example.com
# Redis Configuration (fallback)
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=yourredispassword
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=20
REDIS_ENABLE_CIRCUIT_BREAKER=true
REDIS_ENABLE_HEALTH_CHECK=true
# ============================================================================
# Configuration Cheat Sheet
# ============================================================================
# Minimal Setup (Quick Start):
# redis:
# enabled: true
# address: "redis:6379"
# Production Setup (Recommended):
# redis:
# enabled: true
# address: "redis-master:6379"
# password: "strong-password"
# cacheMode: "hybrid"
# enableCircuitBreaker: true
# enableHealthCheck: true
# High Security Setup:
# redis:
# enabled: true
# address: "redis.example.com:6380"
# password: "strong-password"
# enableTLS: true
# tlsSkipVerify: false
# cacheMode: "redis"
# Cache Modes:
# - "memory": Local cache only (default, no Redis needed)
# - "redis": Redis only (consistent, shared across replicas)
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
+149
View File
@@ -0,0 +1,149 @@
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
# This example shows how to configure Redis through Traefik's dynamic configuration
# Static configuration (traefik.yml)
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
# Dynamic configuration (dynamic.yml or labels)
http:
middlewares:
# Example 1: Basic Redis configuration
oidc-redis-basic:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis configuration
redis:
enabled: true
address: "redis:6379"
# password: "your-redis-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
# Example 2: Redis with resilience features
oidc-redis-resilient:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with full resilience configuration
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
db: 1
keyPrefix: "myapp:"
poolSize: 20
connectTimeout: 10
readTimeout: 5
writeTimeout: 5
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
# Circuit breaker settings
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
# Health check settings
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS
oidc-redis-tls:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with TLS configuration
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
enableTLS: true
tlsSkipVerify: false # Set to true only for testing
cacheMode: "redis"
routers:
my-app:
rule: "Host(`app.example.com`)"
middlewares:
- oidc-redis-basic
service: my-app-service
services:
my-app-service:
loadBalancer:
servers:
- url: "http://localhost:8080"
# Docker Compose labels example
# version: '3.8'
# services:
# traefik:
# image: traefik:v3.0
# # ... other config ...
#
# my-app:
# image: my-app:latest
# labels:
# - "traefik.enable=true"
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
# - "traefik.http.routers.my-app.middlewares=my-oidc"
# # OIDC middleware configuration with Redis
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# # Redis configuration via labels
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
#
# redis:
# image: redis:7-alpine
# command: redis-server --requirepass redis-password
# # ... other config ...
# Environment variable fallback (optional)
# If Redis configuration is not provided in Traefik config, these environment variables
# can be used as a fallback (but Traefik config takes precedence):
#
# REDIS_ENABLED=true
# REDIS_ADDRESS=redis:6379
# REDIS_PASSWORD=secret
# REDIS_DB=0
# REDIS_KEY_PREFIX=traefikoidc:
# REDIS_CACHE_MODE=redis
# REDIS_POOL_SIZE=10
# REDIS_CONNECT_TIMEOUT=5
# REDIS_READ_TIMEOUT=3
# REDIS_WRITE_TIMEOUT=3
# REDIS_ENABLE_TLS=false
# REDIS_TLS_SKIP_VERIFY=false
# REDIS_ENABLE_CIRCUIT_BREAKER=true
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
# REDIS_ENABLE_HEALTH_CHECK=true
# REDIS_HEALTH_CHECK_INTERVAL=30
-797
View File
@@ -1,797 +0,0 @@
package features
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock types for testing
type TemplatedHeader struct {
Name string `json:"name"`
Value string `json:"value"`
}
type MockConfig struct {
ProviderURL string `json:"providerURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackURL"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
Headers []TemplatedHeader `json:"headers"`
}
// TestTemplateHeaderFeatures consolidates all template header-related tests
func TestTemplateHeaderFeatures(t *testing.T) {
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
t.Run("Template_Execution_Context", testTemplateExecutionContext)
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
t.Run("Missing_Field_Handling", testMissingFieldHandling)
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
}
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
// receive wrong data types during execution - reproduces GitHub issue #55
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
testCases := []struct {
name string
templateText string
templateData interface{}
errorContains string
expectError bool
}{
{
name: "correct map data",
templateText: "Bearer {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "valid-token",
},
expectError: false,
},
{
name: "boolean as root context - reproduces issue #55",
templateText: "Bearer {{.AccessToken}}",
templateData: true,
expectError: true,
errorContains: "can't evaluate field AccessToken in type bool",
},
{
name: "string as root context",
templateText: "Bearer {{.AccessToken}}",
templateData: "just a string",
expectError: true,
errorContains: "can't evaluate field AccessToken in type string",
},
{
name: "nested claims access with correct data",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
},
},
expectError: false,
},
{
name: "nested claims with wrong structure",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": "not a map",
},
expectError: true,
errorContains: "can't evaluate field email in type",
},
{
name: "complex nested structure",
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "token123",
"Claims": map[string]interface{}{
"sub": "user-id",
"groups": "admin,users",
},
},
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.templateData)
if tc.expectError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
} else {
require.NoError(t, err)
}
})
}
}
// testTemplateParsingValidation ensures templates are parsed correctly
func testTemplateParsingValidation(t *testing.T) {
testCases := []struct {
name string
headerTemplates []TemplatedHeader
shouldError bool
}{
{
name: "valid bearer token template",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
shouldError: false,
},
{
name: "multiple valid templates",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
shouldError: false,
},
{
name: "template with conditional logic",
headerTemplates: []TemplatedHeader{
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
},
shouldError: false,
},
{
name: "invalid template syntax",
headerTemplates: []TemplatedHeader{
{Name: "Bad-Template", Value: "{{.AccessToken"},
},
shouldError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, header := range tc.headerTemplates {
_, err := template.New(header.Name).Parse(header.Value)
if tc.shouldError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
})
}
}
// testMiddlewareHeaderTemplating simulates the actual middleware flow
func testMiddlewareHeaderTemplating(t *testing.T) {
testCases := []struct {
name string
headers []TemplatedHeader
accessToken string
idToken string
claims map[string]interface{}
expectedValues map[string]string
}{
{
name: "authorization header with access token",
headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
expectedValues: map[string]string{
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
},
},
{
name: "multiple headers with claims",
headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
},
accessToken: "token123",
claims: map[string]interface{}{
"email": "user@example.com",
"groups": "admin,developers",
},
expectedValues: map[string]string{
"X-User-Email": "user@example.com",
"X-User-Groups": "admin,developers",
"X-Auth-Token": "token123",
},
},
{
name: "complex template expressions",
headers: []TemplatedHeader{
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
},
accessToken: "access-token",
idToken: "id-token",
claims: map[string]interface{}{
"sub": "user-12345",
"email": "john@example.com",
},
expectedValues: map[string]string{
"X-User-Info": "user-12345 (john@example.com)",
"X-Auth-Header": "Bearer access-token | ID: id-token",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Parse all templates
headerTemplates := make(map[string]*template.Template)
for _, header := range tc.headers {
tmpl, err := template.New(header.Name).Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = tmpl
}
// Create template data
templateData := map[string]interface{}{
"AccessToken": tc.accessToken,
"IDToken": tc.idToken,
"Claims": tc.claims,
}
// Create a test request
req := httptest.NewRequest("GET", "/test", nil)
// Execute templates and set headers
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
req.Header.Set(headerName, buf.String())
}
// Verify all expected headers are set correctly
for headerName, expectedValue := range tc.expectedValues {
actualValue := req.Header.Get(headerName)
assert.Equal(t, expectedValue, actualValue)
}
})
}
}
// testJSONConfigParsing tests that JSON configuration is properly parsed
func testJSONConfigParsing(t *testing.T) {
testCases := []struct {
name string
jsonConfig string
expectedError bool
description string
}{
{
name: "valid JSON configuration",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": "Bearer {{.AccessToken}}"
}
]
}`,
expectedError: false,
description: "Properly formatted JSON with string values",
},
{
name: "JSON with boolean value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": true
}
]
}`,
expectedError: true,
description: "Boolean value instead of string template",
},
{
name: "JSON with number value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": 123
}
]
}`,
expectedError: true,
description: "Number value instead of string template",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var config struct {
Headers []TemplatedHeader `json:"headers"`
}
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
if tc.expectedError {
require.Error(t, err, tc.description)
} else {
require.NoError(t, err, tc.description)
}
})
}
}
// testTemplateDoubleProcessing tests if template strings are being double-processed
func testTemplateDoubleProcessing(t *testing.T) {
// Simulate how Traefik passes config to the plugin
config := &MockConfig{
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Verify that template strings are still raw (not processed)
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
// Simulate template parsing during initialization
headerTemplates := make(map[string]*template.Template)
funcMap := template.FuncMap{
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" || val == "<no value>" {
return defaultVal
}
return val
},
"get": func(m interface{}, key string) interface{} {
if mapVal, ok := m.(map[string]interface{}); ok {
if val, exists := mapVal[key]; exists {
return val
}
}
return ""
},
}
for _, header := range config.Headers {
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
parsedTmpl, err := tmpl.Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = parsedTmpl
}
// Test execution with actual claims
claims := map[string]interface{}{
"email": "user@example.com",
// Note: internal_role is missing
}
templateData := map[string]interface{}{
"Claims": claims,
}
// Execute templates
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
result := buf.String()
if headerName == "X-User-Email" {
assert.Equal(t, "user@example.com", result)
} else if headerName == "X-User-Role" {
// With missingkey=zero, missing fields return "<no value>"
assert.Equal(t, "<no value>", result)
}
}
}
// testTemplateExecutionContext tests the specific template data context
func testTemplateExecutionContext(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expectedValue string
}{
{
name: "Access and ID token distinction",
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IDToken": "id-token-value",
"Claims": map[string]interface{}{},
},
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "access-token",
"IDToken": "id-token",
"Claims": map[string]interface{}{
"sub": "user123",
},
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Custom non-standard claims",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"Claims": map[string]interface{}{
"role": "admin",
"permissions": "read:all,write:own",
},
},
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
},
{
name: "Deeply nested custom claims",
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"app_metadata": map[string]interface{}{
"organization": map[string]interface{}{
"name": "acme-corp",
},
"team": "platform",
},
},
},
expectedValue: "X-Organization: acme-corp, X-Team: platform",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expectedValue, buf.String())
})
}
}
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
func testTemplateIntegrationWithPlugin(t *testing.T) {
// Test template integration using mock plugin components
// Set up test OIDC server
var testServerURL string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": testServerURL,
"authorization_endpoint": testServerURL + "/auth",
"token_endpoint": testServerURL + "/token",
"jwks_uri": testServerURL + "/jwks",
"userinfo_endpoint": testServerURL + "/userinfo",
})
case "/jwks":
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
http.NotFound(w, r)
}
}))
defer testServer.Close()
testServerURL = testServer.URL
// Create config with templates that reference potentially missing fields
config := &MockConfig{
ProviderURL: testServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-characters",
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Initialize plugin would be done here
ctx := context.Background()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test would create plugin handler here
_ = ctx
_ = next
_ = config
}
// testTemplateSyntaxValidation tests that template syntax is properly validated
func testTemplateSyntaxValidation(t *testing.T) {
validTemplates := []string{
"{{.Claims.email}}",
"{{.Claims.internal_role}}",
"{{.AccessToken}}",
"{{.IdToken}}",
"{{.RefreshToken}}",
}
for _, tmplStr := range validTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
}
// Test invalid templates
invalidTemplates := []struct {
template string
reason string
}{
{"{{call .SomeFunc}}", "function calls not allowed"},
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
{"{{index .Array 0}}", "index access blocked"},
{"{{printf \"%s\" .Data}}", "printf blocked"},
}
for _, tc := range invalidTemplates {
err := validateTemplateSecure(tc.template)
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
}
// Test safe custom functions
safeTemplates := []string{
"{{get .Claims \"internal_role\"}}",
"{{default \"guest\" .Claims.role}}",
}
for _, tmplStr := range safeTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
}
}
// Mock validation function for template security
func validateTemplateSecure(templateStr string) error {
// List of potentially dangerous template actions
dangerousFunctions := []string{
"call", "range", "with", "index", "printf", "println", "print",
"js", "html", "urlquery", "base64", "exec",
}
for _, dangerous := range dangerousFunctions {
if strings.Contains(templateStr, dangerous) {
return fmt.Errorf("dangerous template function detected: %s", dangerous)
}
}
// Define safe custom functions
funcMap := template.FuncMap{
"get": func(data map[string]interface{}, key string) interface{} {
return data[key]
},
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" {
return defaultVal
}
return val
},
}
// Try to parse the template with custom functions to check for syntax errors
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
return err
}
// testMissingFieldHandling tests handling of missing fields in templates
func testMissingFieldHandling(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "missing claim field",
templateText: "{{.Claims.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{},
},
expected: "<no value>",
},
{
name: "missing nested field",
templateText: "{{.Claims.user.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"user": map[string]interface{}{},
},
},
expected: "<no value>",
},
{
name: "missing entire path",
templateText: "{{.Missing.Path.Field}}",
data: map[string]interface{}{},
expected: "<no value>",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testComplexTemplateExpressions tests complex template expressions
func testComplexTemplateExpressions(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "conditional template",
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"admin": true,
},
},
expected: "Admin User",
},
{
name: "multiple claims concatenation",
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"firstName": "John",
"lastName": "Doe",
"email": "john.doe@example.com",
},
},
expected: "John Doe <john.doe@example.com>",
},
{
name: "array access",
templateText: "{{index .Claims.roles 0}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"roles": []string{"admin", "user"},
},
},
expected: "admin",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
func testTraefikConfigurationParsing(t *testing.T) {
testCases := []struct {
name string
config *MockConfig
expectError bool
description string
}{
{
name: "valid configuration with templated headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
},
expectError: false,
description: "Standard configuration should work",
},
{
name: "configuration with multiple headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
},
expectError: false,
description: "Multiple headers should work",
},
{
name: "empty headers configuration",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{},
},
expectError: false,
description: "Empty headers should not cause issues",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a simple next handler
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Try to create the middleware would be done here
ctx := context.Background()
// Test would create middleware handler here
_ = ctx
_ = next
_ = tc.config
// For now, we just validate the configuration is well-formed
if !tc.expectError {
require.NotNil(t, tc.config, tc.description)
require.NotEmpty(t, tc.config.ClientID, tc.description)
}
})
}
}
+8 -2
View File
@@ -3,15 +3,21 @@ module github.com/lukaszraczylo/traefikoidc
go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.13.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
)
+18 -2
View File
@@ -1,5 +1,15 @@
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -10,10 +20,16 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+7 -7
View File
@@ -10,16 +10,16 @@ import (
type GoroutineManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
goroutines map[string]*managedGoroutine
logger *Logger
wg sync.WaitGroup
mu sync.RWMutex
}
type managedGoroutine struct {
name string
cancel context.CancelFunc
startTime time.Time
cancel context.CancelFunc
name string
running bool
}
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
for {
select {
case <-ctx.Done():
m.logger.Debugf("Periodic task %s cancelled", name)
m.logger.Debugf("Periodic task %s canceled", name)
return
case <-ticker.C:
task()
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
// GoroutineStatus represents the status of a managed goroutine
type GoroutineStatus struct {
Name string
Running bool
StartTime time.Time
Name string
Runtime time.Duration
Running bool
}
// ErrShutdownTimeout is returned when shutdown times out
+625
View File
@@ -0,0 +1,625 @@
package traefikoidc
import (
"context"
"sync/atomic"
"testing"
"time"
)
// Test GoroutineManager Creation
func TestNewGoroutineManager(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
if gm == nil {
t.Fatal("Expected non-nil goroutine manager")
}
if gm.ctx == nil {
t.Error("Expected context to be initialized")
}
if gm.cancel == nil {
t.Error("Expected cancel function to be initialized")
}
if gm.goroutines == nil {
t.Error("Expected goroutines map to be initialized")
}
if gm.logger != logger {
t.Error("Expected logger to be set")
}
}
// Test Starting Goroutines
func TestStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
executed.Store(true)
})
// Give goroutine time to execute
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if len(status) != 1 {
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
}
if _, exists := status["test-goroutine"]; !exists {
t.Error("Expected goroutine 'test-goroutine' in status")
}
}
func TestStartGoroutineDuplicate(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
// Start a long-running goroutine
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
})
// Give first goroutine time to start
time.Sleep(50 * time.Millisecond)
// Try to start another with same name (should be skipped)
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
// Should only have executed once
if counter.Load() != 1 {
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
}
}
func TestStartGoroutineContextCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
started := atomic.Bool{}
canceled := atomic.Bool{}
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
started.Store(true)
<-ctx.Done()
canceled.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
if !started.Load() {
t.Error("Expected goroutine to start")
}
// Stop the goroutine
gm.StopGoroutine("cancel-test")
// Wait for cancellation
time.Sleep(50 * time.Millisecond)
if !canceled.Load() {
t.Error("Expected goroutine to be canceled")
}
}
func TestStartGoroutineWithPanic(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("panic-test", func(ctx context.Context) {
executed.Store(true)
panic("test panic")
})
// Give goroutine time to panic and recover
time.Sleep(100 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute before panic")
}
// Check that goroutine is marked as not running after panic
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running after panic")
}
}
// Manager should still be functional
counter := atomic.Int32{}
gm.StartGoroutine("after-panic", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
if counter.Load() != 1 {
t.Error("Expected manager to still be functional after panic recovery")
}
}
// Test Periodic Tasks
func TestStartPeriodicTask(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for multiple executions
time.Sleep(160 * time.Millisecond)
// Should have executed at least 2-3 times
count := counter.Load()
if count < 2 {
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
}
}
func TestStartPeriodicTaskCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for some executions
time.Sleep(120 * time.Millisecond)
// Stop the task
gm.StopGoroutine("cancel-periodic")
countBeforeStop := counter.Load()
// Wait and verify no more executions
time.Sleep(120 * time.Millisecond)
countAfterStop := counter.Load()
// Allow 1 additional execution (could be in progress when stopped)
if countAfterStop > countBeforeStop+1 {
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
countBeforeStop, countAfterStop)
}
}
// Test Stopping Goroutines
func TestStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
stopped := atomic.Bool{}
gm.StartGoroutine("stop-test", func(ctx context.Context) {
<-ctx.Done()
stopped.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
gm.StopGoroutine("stop-test")
// Wait for goroutine to stop
time.Sleep(50 * time.Millisecond)
if !stopped.Load() {
t.Error("Expected goroutine to be stopped")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["stop-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running")
}
}
}
func TestStopGoroutineNonExistent(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Should not panic or error when stopping non-existent goroutine
gm.StopGoroutine("non-existent")
}
func TestStopGoroutineAlreadyStopped(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
// Exit immediately
})
// Wait for goroutine to finish
time.Sleep(50 * time.Millisecond)
// Try to stop already-stopped goroutine (should be safe)
gm.StopGoroutine("already-stopped")
}
// Test Shutdown
func TestShutdownGraceful(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
counter := atomic.Int32{}
// Start multiple goroutines
for i := 0; i < 5; i++ {
name := "goroutine-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
counter.Add(-1)
})
}
// Wait for all to start
time.Sleep(100 * time.Millisecond)
if counter.Load() != 5 {
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
}
// Shutdown with generous timeout
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected graceful shutdown, got error: %v", err)
}
if counter.Load() != 0 {
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
}
}
func TestShutdownWithTimeout(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
gm.StartGoroutine("stubborn", func(ctx context.Context) {
// Simulate a goroutine that takes too long to stop
time.Sleep(500 * time.Millisecond)
})
time.Sleep(50 * time.Millisecond)
// Shutdown with very short timeout
err := gm.Shutdown(10 * time.Millisecond)
if err == nil {
t.Error("Expected timeout error")
}
if err != ErrShutdownTimeout {
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
}
}
func TestShutdownEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown with no goroutines should succeed immediately
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected no error for empty shutdown, got: %v", err)
}
}
// Test Status
func TestGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start multiple goroutines with different states
gm.StartGoroutine("running", func(ctx context.Context) {
<-ctx.Done()
})
gm.StartGoroutine("quick", func(ctx context.Context) {
// Exits immediately
})
time.Sleep(50 * time.Millisecond)
status := gm.GetStatus()
if len(status) != 2 {
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
}
if runningStatus, exists := status["running"]; exists {
if !runningStatus.Running {
t.Error("Expected 'running' goroutine to be marked as running")
}
if runningStatus.Name != "running" {
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
}
if runningStatus.StartTime.IsZero() {
t.Error("Expected non-zero start time")
}
if runningStatus.Runtime <= 0 {
t.Error("Expected positive runtime")
}
} else {
t.Error("Expected 'running' goroutine in status")
}
if quickStatus, exists := status["quick"]; exists {
if quickStatus.Running {
t.Error("Expected 'quick' goroutine to be marked as not running")
}
} else {
t.Error("Expected 'quick' goroutine in status")
}
}
func TestGetStatusEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
status := gm.GetStatus()
if status == nil {
t.Fatal("Expected non-nil status map")
}
if len(status) != 0 {
t.Errorf("Expected empty status, got %d entries", len(status))
}
}
// Test Concurrent Operations
func TestConcurrentStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(2 * time.Second)
counter := atomic.Int32{}
const numGoroutines = 50
// Start many goroutines concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
time.Sleep(50 * time.Millisecond)
counter.Add(-1)
})
}(i)
}
// Wait for all to start
time.Sleep(150 * time.Millisecond)
// Verify goroutines are tracked
status := gm.GetStatus()
if len(status) < numGoroutines/2 {
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
}
}
func TestConcurrentStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
const numGoroutines = 20
// Start goroutines
for i := 0; i < numGoroutines; i++ {
name := "stop-concurrent-" + string(rune('0'+i%10))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
time.Sleep(50 * time.Millisecond)
// Stop all concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "stop-concurrent-" + string(rune('0'+id%10))
gm.StopGoroutine(name)
}(i)
}
time.Sleep(100 * time.Millisecond)
// Verify all stopped
status := gm.GetStatus()
for _, s := range status {
if s.Running {
t.Errorf("Expected goroutine %s to be stopped", s.Name)
}
}
}
func TestConcurrentGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start some goroutines
for i := 0; i < 10; i++ {
name := "status-test-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
// Concurrently read status many times (should not race)
done := make(chan struct{})
for i := 0; i < 20; i++ {
go func() {
for j := 0; j < 100; j++ {
_ = gm.GetStatus()
}
done <- struct{}{}
}()
}
// Wait for all concurrent reads
for i := 0; i < 20; i++ {
<-done
}
}
// Test Error Cases
func TestShutdownTimeoutError(t *testing.T) {
err := ErrShutdownTimeout
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
t.Errorf("Unexpected error message: %s", err.Error())
}
}
// Test Edge Cases
func TestStartGoroutineAfterShutdown(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown immediately
_ = gm.Shutdown(time.Second)
executed := atomic.Bool{}
// Try to start goroutine after shutdown
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
executed.Store(true)
<-ctx.Done()
})
time.Sleep(50 * time.Millisecond)
// Goroutine should have started but context already canceled
// It may or may not execute depending on timing, but shouldn't panic
status := gm.GetStatus()
if _, exists := status["after-shutdown"]; exists {
// If it's in status, it was tracked (acceptable)
t.Log("Goroutine was tracked even after shutdown")
}
}
func TestMultipleShutdowns(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// First shutdown
err1 := gm.Shutdown(time.Second)
if err1 != nil {
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
}
// Second shutdown (should not panic or error)
err2 := gm.Shutdown(time.Second)
if err2 != nil {
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
}
}
func TestGoroutineWithImmediateReturn(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("immediate", func(ctx context.Context) {
executed.Store(true)
// Return immediately
})
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["immediate"]; exists {
if goroutineStatus.Running {
t.Error("Expected immediately-returning goroutine to be marked as not running")
}
}
}
func TestPeriodicTaskPanicRecovery(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
counter.Add(1)
if counter.Load() == 2 {
panic("periodic panic")
}
})
// Wait for panic to occur
time.Sleep(200 * time.Millisecond)
// After panic, the goroutine should have stopped
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-periodic"]; exists {
if goroutineStatus.Running {
t.Error("Expected panicked periodic task to stop")
}
}
}
-764
View File
@@ -1,764 +0,0 @@
package handlers
import (
"errors"
"net/http"
"sync"
"testing"
"time"
)
// ============================================================================
// OAuth Handler Tests
// ============================================================================
func TestOAuthHandler(t *testing.T) {
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
// Test authorization request handling logic
logger := &MockLogger{}
tests := []struct {
name string
requestURL string
expectedStatus int
checkLocation bool
}{
{
name: "Valid authorization request",
requestURL: "/auth/login",
expectedStatus: http.StatusFound,
checkLocation: true,
},
{
name: "With return URL",
requestURL: "/auth/login?return=/dashboard",
expectedStatus: http.StatusFound,
checkLocation: true,
},
}
// Test the test case structure
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.requestURL == "" {
t.Error("Request URL should not be empty")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// In a real implementation, this would test the actual handler
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Authorization request test completed")
})
t.Run("HandleCallbackRequest", func(t *testing.T) {
// Test callback request handling with existing mocks
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
tests := []struct {
name string
queryParams string
expectedStatus int
expectError bool
}{
{
name: "Valid callback with code",
queryParams: "code=test-code&state=test-state",
expectedStatus: http.StatusFound,
expectError: false,
},
{
name: "Callback with error",
queryParams: "error=access_denied&error_description=User denied access",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing code",
queryParams: "state=test-state",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing state",
queryParams: "code=test-code",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
}
// Test the callback scenarios
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.queryParams == "" && !test.expectError {
t.Error("Query params should not be empty for successful cases")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// Test session manager functionality
if sessionManager != nil {
t.Logf("Session manager available for test %s", test.name)
}
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Callback request test completed")
})
t.Run("HandleLogout", func(t *testing.T) {
// Test logout functionality with mock implementations
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
// Test session clearing
mockReq := &http.Request{}
session, err := sessionManager.GetSession(mockReq)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up authenticated session
err = session.SetAuthenticated(true)
if err != nil {
t.Fatalf("Failed to set authentication: %v", err)
}
session.SetIDToken("test-token")
// Verify session is authenticated
if !session.GetAuthenticated() {
t.Error("Session should be authenticated before logout")
}
// Test logout by clearing session
// session.Clear() // Method not implemented in SessionData
// Additional logout verification would go here
// Verify logger doesn't cause issues
logger.Debugf("Logout test completed")
t.Log("Logout test completed successfully")
})
}
// ============================================================================
// Auth Handler Tests
// ============================================================================
func TestAuthHandler(t *testing.T) {
t.Run("HandleAuthentication", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
/*
handler := &MockAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
tests := []struct {
name string
setupSession func(*MockSession)
expectedStatus int
expectNext bool
}{
{
name: "Authenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("valid-token")
},
expectedStatus: http.StatusOK,
expectNext: true,
},
{
name: "Unauthenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(false)
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
{
name: "Expired token",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("expired-token")
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("HandleRefreshToken", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
tests := []struct {
name string
refreshToken string
mockResponse *MockTokenResponse
mockError error
expectSuccess bool
}{
{
name: "Successful refresh",
refreshToken: "valid-refresh-token",
mockResponse: &MockTokenResponse{
AccessToken: "new-access-token",
IDToken: "new-id-token",
RefreshToken: "new-refresh-token",
},
expectSuccess: true,
},
{
name: "Failed refresh",
refreshToken: "invalid-refresh-token",
mockError: errors.New("invalid_grant"),
expectSuccess: false,
},
{
name: "Empty refresh token",
refreshToken: "",
expectSuccess: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Error Handler Tests
// ============================================================================
func TestErrorHandler(t *testing.T) {
t.Run("HandleHTTPErrors", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
errorCode int
errorMessage string
isAjax bool
expectedStatus int
expectedBody string
}{
{
name: "401 Unauthorized",
errorCode: http.StatusUnauthorized,
errorMessage: "Authentication required",
isAjax: false,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Authentication required",
},
{
name: "403 Forbidden",
errorCode: http.StatusForbidden,
errorMessage: "Access denied",
isAjax: false,
expectedStatus: http.StatusForbidden,
expectedBody: "Access denied",
},
{
name: "500 Internal Server Error",
errorCode: http.StatusInternalServerError,
errorMessage: "Internal server error",
isAjax: false,
expectedStatus: http.StatusInternalServerError,
expectedBody: "Internal server error",
},
{
name: "Ajax 401",
errorCode: http.StatusUnauthorized,
errorMessage: "Token expired",
isAjax: true,
expectedStatus: http.StatusUnauthorized,
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("RecoverFromPanic", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
panicValue interface{}
expectError bool
}{
{
name: "String panic",
panicValue: "something went wrong",
expectError: true,
},
{
name: "Error panic",
panicValue: errors.New("critical error"),
expectError: true,
},
{
name: "Nil panic",
panicValue: nil,
expectError: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Azure OAuth Callback Tests
// ============================================================================
func TestAzureOAuthCallback(t *testing.T) {
t.Run("AzureSpecificClaims", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
azureClaims := map[string]interface{}{
"oid": "object-id",
"tid": "tenant-id",
"preferred_username": "user@example.com",
"name": "Test User",
"email": "user@example.com",
"groups": []string{"group1", "group2"},
}
// Test would go here when properly implemented
_ = azureClaims
})
t.Run("AzureTokenValidation", func(t *testing.T) {
// Test with mock validator types
/*
validator := &MockAzureTokenValidator{
tenantID: "test-tenant",
clientID: "test-client",
}
*/
tests := []struct {
name string
token string
claims map[string]interface{}
expectValid bool
}{
{
name: "Valid Azure token",
token: "valid-azure-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: true,
},
{
name: "Wrong tenant",
token: "wrong-tenant-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "wrong-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
{
name: "Wrong audience",
token: "wrong-audience-token",
claims: map[string]interface{}{
"aud": "wrong-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Concurrent Handler Tests
// ============================================================================
func TestConcurrentHandlers(t *testing.T) {
t.Run("ConcurrentCallbacks", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
successCount := int32(0)
errorCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = successCount
_ = errorCount
})
t.Run("ConcurrentLogouts", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
logoutCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = logoutCount
})
}
// ============================================================================
// Mock Implementations
// ============================================================================
type MockSessionManager struct {
sessions map[string]*MockSession
mu sync.RWMutex
}
func NewMockSessionManager() *MockSessionManager {
return &MockSessionManager{
sessions: make(map[string]*MockSession),
}
}
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
m.mu.Lock()
defer m.mu.Unlock()
sessionID := "test-session"
if session, exists := m.sessions[sessionID]; exists {
return session, nil
}
session := &MockSession{
values: make(map[string]interface{}),
}
m.sessions[sessionID] = session
return session, nil
}
type MockSession struct {
values map[string]interface{}
mu sync.RWMutex
}
func (s *MockSession) SetAuthenticated(auth bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.values["authenticated"] = auth
return nil
}
func (s *MockSession) GetAuthenticated() bool {
s.mu.RLock()
defer s.mu.RUnlock()
auth, ok := s.values["authenticated"].(bool)
return ok && auth
}
func (s *MockSession) SetIDToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["id_token"] = token
}
func (s *MockSession) GetIDToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["id_token"].(string)
return token
}
func (s *MockSession) SetAccessToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["access_token"] = token
}
func (s *MockSession) GetAccessToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["access_token"].(string)
return token
}
func (s *MockSession) SetRefreshToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["refresh_token"] = token
}
func (s *MockSession) GetRefreshToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["refresh_token"].(string)
return token
}
func (s *MockSession) SetState(state string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["state"] = state
}
func (s *MockSession) GetState() string {
s.mu.RLock()
defer s.mu.RUnlock()
state, _ := s.values["state"].(string)
return state
}
func (s *MockSession) SetClaims(claims map[string]interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["claims"] = claims
}
func (s *MockSession) GetClaims() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
claims, _ := s.values["claims"].(map[string]interface{})
return claims
}
// Additional SessionData interface methods to match real interface
func (s *MockSession) GetCSRF() string {
s.mu.RLock()
defer s.mu.RUnlock()
csrf, _ := s.values["csrf"].(string)
return csrf
}
func (s *MockSession) GetNonce() string {
s.mu.RLock()
defer s.mu.RUnlock()
nonce, _ := s.values["nonce"].(string)
return nonce
}
func (s *MockSession) GetCodeVerifier() string {
s.mu.RLock()
defer s.mu.RUnlock()
verifier, _ := s.values["code_verifier"].(string)
return verifier
}
func (s *MockSession) GetIncomingPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
path, _ := s.values["incoming_path"].(string)
return path
}
func (s *MockSession) SetEmail(email string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["email"] = email
}
func (s *MockSession) GetEmail() string {
s.mu.RLock()
defer s.mu.RUnlock()
email, _ := s.values["email"].(string)
return email
}
func (s *MockSession) SetCSRF(csrf string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["csrf"] = csrf
}
func (s *MockSession) SetNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["nonce"] = nonce
}
func (s *MockSession) SetCodeVerifier(verifier string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["code_verifier"] = verifier
}
func (s *MockSession) SetIncomingPath(path string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["incoming_path"] = path
}
func (s *MockSession) ResetRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.values["redirect_count"] = 0
}
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
return nil
}
func (s *MockSession) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.values = make(map[string]interface{})
}
func (s *MockSession) returnToPoolSafely() {
// No-op for mock
}
type MockTokenValidator struct {
valid bool
}
func (v *MockTokenValidator) Validate(token string) bool {
if token == "expired-token" {
return false
}
return v.valid
}
// ============================================================================
// Mock Handler Type Definitions (for testing)
// ============================================================================
// These mock handlers are simplified versions for testing purposes
// They don't match the actual handler implementations
type MockAuthHandler struct{}
type MockErrorHandler struct{}
type MockAzureTokenValidator struct {
tenantID string
clientID string
}
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
// Validate tenant ID
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
return false
}
// Validate audience
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
return false
}
// Validate expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return false
}
}
return true
}
// ============================================================================
// Helper Types and Mock Logger
// ============================================================================
type MockLogger struct{}
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
func (l *MockLogger) Error(msg string) {}
type MockTokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
}
-308
View File
@@ -1,308 +0,0 @@
// Package handlers provides HTTP request handlers for the OIDC middleware.
package handlers
import (
"context"
"fmt"
"net/http"
"strings"
)
// OAuthHandler handles OAuth callback requests
type OAuthHandler struct {
logger Logger
sessionManager SessionManager
tokenExchanger TokenExchanger
tokenVerifier TokenVerifier
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
isAllowedDomainFunc func(email string) bool
redirURLPath string
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Error(msg string)
}
// SessionManager interface for session operations
type SessionManager interface {
GetSession(req *http.Request) (SessionData, error)
}
// SessionData interface for session data operations
type SessionData interface {
GetCSRF() string
GetNonce() string
GetCodeVerifier() string
GetIncomingPath() string
GetAuthenticated() bool
GetAccessToken() string
GetRefreshToken() string
GetIDToken() string
GetEmail() string
SetAuthenticated(bool) error
SetEmail(string)
SetIDToken(string)
SetAccessToken(string)
SetRefreshToken(string)
SetCSRF(string)
SetNonce(string)
SetCodeVerifier(string)
SetIncomingPath(string)
ResetRedirectCount()
Save(req *http.Request, rw http.ResponseWriter) error
returnToPoolSafely()
}
// TokenExchanger interface for token operations
type TokenExchanger interface {
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
}
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// TokenResponse represents the response from token exchange
type TokenResponse struct {
IDToken string
AccessToken string
RefreshToken string
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
isAllowedDomainFunc func(string) bool, redirURLPath string,
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
return &OAuthHandler{
logger: logger,
sessionManager: sessionManager,
tokenExchanger: tokenExchanger,
tokenVerifier: tokenVerifier,
extractClaimsFunc: extractClaimsFunc,
isAllowedDomainFunc: isAllowedDomainFunc,
redirURLPath: redirURLPath,
sendErrorResponseFunc: sendErrorResponseFunc,
}
}
// HandleCallback handles OAuth callback requests
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := h.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Session error during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Debug logging for cookie configuration
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
// Log all cookies in the request for debugging
cookies := req.Cookies()
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, "_oidc_") {
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
}
}
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
h.logger.Error("No state in callback")
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
// Debug log the state parameter received
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
csrfToken := session.GetCSRF()
if csrfToken == "" {
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
// Enhanced debugging for missing CSRF token
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
h.logger.Errorf("Main session cookie not found in request: %v", err)
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
} else {
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
}
// Log session state for debugging
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
session.GetAuthenticated(), session.GetAccessToken() != "")
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
// Debug log successful CSRF token retrieval
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
if state != csrfToken {
h.logger.Error("State parameter does not match CSRF token in session during callback")
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
h.logger.Error("No code in callback")
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
h.logger.Errorf("Failed to extract claims during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
h.logger.Error("Nonce claim missing in id_token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
h.logger.Error("Nonce not found in session during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
h.logger.Error("Nonce claim does not match session nonce during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
email, _ := claims["email"].(string)
if email == "" {
h.logger.Errorf("Email claim missing or empty in token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !h.isAllowedDomainFunc(email) {
h.logger.Errorf("Disallowed email domain during callback: %s", email)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session after callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// URLHelper provides utility methods for URL operations
type URLHelper struct {
logger Logger
}
// NewURLHelper creates a new URL helper
func NewURLHelper(logger Logger) *URLHelper {
return &URLHelper{logger: logger}
}
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
// It compares the request path against configured excluded URL prefixes.
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
for excludedURL := range excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
return false
}
// DetermineScheme determines the URL scheme for building redirect URLs.
// It checks X-Forwarded-Proto header first, then TLS presence.
func (h *URLHelper) DetermineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// DetermineHost determines the host for building redirect URLs.
// It checks X-Forwarded-Host header first, then falls back to req.Host.
func (h *URLHelper) DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
-899
View File
@@ -1,899 +0,0 @@
package handlers
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// Test mocks - implementing interfaces defined in oauth_handler.go
type mockLogger struct {
debugMessages []string
errorMessages []string
}
func (l *mockLogger) Debugf(format string, args ...interface{}) {
l.debugMessages = append(l.debugMessages, format)
}
func (l *mockLogger) Errorf(format string, args ...interface{}) {
l.errorMessages = append(l.errorMessages, format)
}
func (l *mockLogger) Error(msg string) {
l.errorMessages = append(l.errorMessages, msg)
}
type mockSessionManager struct {
sessionToReturn SessionData
errorToReturn error
}
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
return m.sessionToReturn, m.errorToReturn
}
type mockSessionData struct {
authenticated bool
email string
csrf string
nonce string
codeVerifier string
incomingPath string
accessToken string
refreshToken string
idToken string
saveError error
setAuthError error
}
func (s *mockSessionData) GetCSRF() string { return s.csrf }
func (s *mockSessionData) GetNonce() string { return s.nonce }
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
func (s *mockSessionData) GetIDToken() string { return s.idToken }
func (s *mockSessionData) GetEmail() string { return s.email }
func (s *mockSessionData) SetAuthenticated(auth bool) error {
s.authenticated = auth
return s.setAuthError
}
func (s *mockSessionData) SetEmail(email string) { s.email = email }
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
func (s *mockSessionData) ResetRedirectCount() {}
func (s *mockSessionData) returnToPoolSafely() {}
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
return s.saveError
}
type mockTokenExchanger struct {
response *TokenResponse
err error
}
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
return e.response, e.err
}
type mockTokenVerifier struct {
err error
}
func (v *mockTokenVerifier) VerifyToken(token string) error {
return v.err
}
// TestOAuthHandler_NewOAuthHandler tests the constructor
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
logger := &mockLogger{}
sessionManager := &mockSessionManager{}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.logger != logger {
t.Error("Logger not set correctly")
}
if handler.redirURLPath != "/callback" {
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
}
}
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
logger := &mockLogger{}
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return nil, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Session error") {
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "Authentication error from provider") {
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
// Test with error parameter
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
if len(logger.errorMessages) == 0 {
t.Error("Expected error to be logged")
}
}
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "State parameter missing") {
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: ""} // Empty CSRF
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "CSRF token missing") {
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "different-token"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "CSRF mismatch") {
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
}
if !strings.Contains(msg, "No authorization code received") {
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not exchange code for token") {
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not verify ID token") {
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return nil, errors.New("claims extraction failed")
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Could not extract claims") {
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
// Claims without nonce
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce missing in token") {
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce missing in session") {
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Nonce mismatch") {
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Email missing in token") {
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return false } // Disallow all domains
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusForbidden {
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
}
if !strings.Contains(msg, "Email domain not allowed") {
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
saveError: errors.New("save failed"),
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Failed to save session") {
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
setAuthError: errors.New("set auth failed"),
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Failed to update session") {
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if !errorSent {
t.Error("Expected error response to be sent")
}
}
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "/dashboard",
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{
IDToken: "valid-id-token",
AccessToken: "valid-access-token",
RefreshToken: "valid-refresh-token",
}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
errorSent := false
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
errorSent = true
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
if errorSent {
t.Error("Unexpected error response sent")
}
// Check redirect
if rw.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
}
location := rw.Header().Get("Location")
if location != "/dashboard" {
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
}
// Verify session data was set correctly
if session.email != "test@example.com" {
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
}
if session.idToken != "valid-id-token" {
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
}
if session.accessToken != "valid-access-token" {
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
}
if session.refreshToken != "valid-refresh-token" {
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
}
if !session.authenticated {
t.Error("Expected session to be authenticated")
}
// Check that temporary fields are cleared
if session.csrf != "" {
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
}
if session.nonce != "" {
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
}
if session.codeVerifier != "" {
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
}
if session.incomingPath != "" {
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
}
}
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "", // No incoming path, should default to "/"
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
// Check redirect to default path
if rw.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
}
location := rw.Header().Get("Location")
if location != "/" {
t.Errorf("Expected redirect to '/', got '%s'", location)
}
}
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
logger := &mockLogger{}
session := &mockSessionData{
csrf: "test-state",
nonce: "test-nonce",
incomingPath: "/callback", // Same as redirect URL path
}
sessionManager := &mockSessionManager{sessionToReturn: session}
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
tokenVerifier := &mockTokenVerifier{}
extractClaims := func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
handler.HandleCallback(rw, req, "http://example.com/callback")
// Should redirect to default path when incoming path is same as callback path
location := rw.Header().Get("Location")
if location != "/" {
t.Errorf("Expected redirect to '/', got '%s'", location)
}
}
-454
View File
@@ -1,454 +0,0 @@
package handlers
import (
"crypto/tls"
"net/http"
"testing"
)
// TestURLHelper_NewURLHelper tests the URLHelper constructor
func TestURLHelper_NewURLHelper(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
if helper == nil {
t.Fatal("Expected URLHelper to be created, got nil")
}
if helper.logger != logger {
t.Error("Logger not set correctly")
}
}
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
currentURL string
excludedURLs map[string]struct{}
expected bool
}{
{
name: "Exact match",
currentURL: "/health",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
{
name: "Prefix match",
currentURL: "/health/status",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
{
name: "No match",
currentURL: "/api/users",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: false,
},
{
name: "Multiple exclusions - first match",
currentURL: "/api/health",
excludedURLs: map[string]struct{}{
"/api": {},
"/health": {},
},
expected: true,
},
{
name: "Multiple exclusions - second match",
currentURL: "/health/check",
excludedURLs: map[string]struct{}{
"/api": {},
"/health": {},
},
expected: true,
},
{
name: "Empty excluded URLs",
currentURL: "/api/users",
excludedURLs: map[string]struct{}{},
expected: false,
},
{
name: "Root path exclusion",
currentURL: "/anything",
excludedURLs: map[string]struct{}{
"/": {},
},
expected: true,
},
{
name: "Case sensitive matching",
currentURL: "/API/users",
excludedURLs: map[string]struct{}{
"/api": {},
},
expected: false,
},
{
name: "Partial substring but not prefix",
currentURL: "/user/api/test",
excludedURLs: map[string]struct{}{
"/api": {},
},
expected: false,
},
{
name: "Empty current URL",
currentURL: "",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: false,
},
{
name: "URL with query parameters",
currentURL: "/health?status=ok",
excludedURLs: map[string]struct{}{
"/health": {},
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
if result != tt.expected {
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
}
// Verify debug logging for excluded URLs
if result && len(logger.debugMessages) > 0 {
// Should have logged a debug message for excluded URL
found := false
for _, msg := range logger.debugMessages {
if msg == "URL is excluded - got %s / excluded hit: %s" {
found = true
break
}
}
if !found {
t.Error("Expected debug message for excluded URL")
}
}
// Reset logger messages for next test
logger.debugMessages = nil
})
}
}
// TestURLHelper_DetermineScheme tests scheme determination
func TestURLHelper_DetermineScheme(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedScheme string
}{
{
name: "X-Forwarded-Proto header present - https",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
return req
},
expectedScheme: "https",
},
{
name: "X-Forwarded-Proto header present - http",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "http")
return req
},
expectedScheme: "http",
},
{
name: "TLS connection without X-Forwarded-Proto",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
return req
},
expectedScheme: "https",
},
{
name: "No TLS and no X-Forwarded-Proto",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
return req
},
expectedScheme: "http",
},
{
name: "X-Forwarded-Proto takes precedence over TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
req.Header.Set("X-Forwarded-Proto", "http")
return req
},
expectedScheme: "http",
},
{
name: "Empty X-Forwarded-Proto falls back to TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
req.Header.Set("X-Forwarded-Proto", "")
return req
},
expectedScheme: "https",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
result := helper.DetermineScheme(req)
if result != tt.expectedScheme {
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
}
})
}
}
// TestURLHelper_DetermineHost tests host determination
func TestURLHelper_DetermineHost(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedHost string
}{
{
name: "X-Forwarded-Host header present",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "public.example.com")
return req
},
expectedHost: "public.example.com",
},
{
name: "No X-Forwarded-Host, use req.Host",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "direct.example.com"
return req
},
expectedHost: "direct.example.com",
},
{
name: "Empty X-Forwarded-Host falls back to req.Host",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "fallback.example.com"
req.Header.Set("X-Forwarded-Host", "")
return req
},
expectedHost: "fallback.example.com",
},
{
name: "X-Forwarded-Host with port",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com:8080"
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
return req
},
expectedHost: "public.example.com:443",
},
{
name: "req.Host with port",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
req.Host = "example.com:8080"
return req
},
expectedHost: "example.com:8080",
},
{
name: "Multiple X-Forwarded-Host values (first one used)",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
return req
},
expectedHost: "first.example.com, second.example.com", // Header value as-is
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
result := helper.DetermineHost(req)
if result != tt.expectedHost {
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
}
})
}
}
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
tests := []struct {
name string
setupRequest func() *http.Request
expectedScheme string
expectedHost string
}{
{
name: "Both headers present",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "public.example.com")
return req
},
expectedScheme: "https",
expectedHost: "public.example.com",
},
{
name: "Neither header present, TLS connection",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
return req
},
expectedScheme: "https",
expectedHost: "secure.example.com",
},
{
name: "Neither header present, no TLS",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
req.Host = "plain.example.com"
return req
},
expectedScheme: "http",
expectedHost: "plain.example.com",
},
{
name: "Mixed - only scheme header",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
req.Host = "mixed.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
return req
},
expectedScheme: "https",
expectedHost: "mixed.example.com",
},
{
name: "Mixed - only host header",
setupRequest: func() *http.Request {
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "external.example.com")
return req
},
expectedScheme: "http",
expectedHost: "external.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupRequest()
scheme := helper.DetermineScheme(req)
host := helper.DetermineHost(req)
if scheme != tt.expectedScheme {
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
}
if host != tt.expectedHost {
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
}
// Test that we can build a complete URL
fullURL := scheme + "://" + host + "/callback"
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
if fullURL != expectedURL {
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
}
})
}
}
// Benchmark tests to ensure the helper methods are performant
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
excludedURLs := map[string]struct{}{
"/health": {},
"/metrics": {},
"/status": {},
"/api/v1": {},
"/api/v2": {},
"/static": {},
"/assets": {},
"/favicon": {},
"/robots": {},
"/sitemap": {},
}
testURL := "/api/users"
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineExcludedURL(testURL, excludedURLs)
}
}
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header.Set("X-Forwarded-Proto", "https")
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineScheme(req)
}
}
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
logger := &mockLogger{}
helper := NewURLHelper(logger)
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Host = "internal.example.com"
req.Header.Set("X-Forwarded-Host", "external.example.com")
b.ResetTimer()
for i := 0; i < b.N; i++ {
helper.DetermineHost(req)
}
}
+22 -10
View File
@@ -13,6 +13,8 @@ import (
"net/url"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
@@ -109,7 +111,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
client := t.tokenHTTPClient
if client == nil {
// Use shared transport pool to prevent memory leaks
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
pooledClient := CreateTokenHTTPClient()
client = &http.Client{
Transport: pooledClient.Transport,
@@ -124,7 +126,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
// Read tokenURL with RLock
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -135,13 +142,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
_ = resp.Body.Close() // Safe to ignore: closing body on defer
}()
if resp.StatusCode != http.StatusOK {
limitReader := io.LimitReader(resp.Body, 1024*10)
bodyBytes, _ := io.ReadAll(limitReader)
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
@@ -232,7 +239,7 @@ func NewTokenCache() *TokenCache {
// - expiration: The duration for which the cache entry should be valid
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
}
// Get retrieves cached claims for a token.
@@ -344,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
host := utils.DetermineHost(req)
scheme := utils.DetermineScheme(req, t.forceHTTPS)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
@@ -355,8 +362,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
// Read endSessionURL with RLock
t.metadataMu.RLock()
endSessionURL := t.endSessionURL
t.metadataMu.RUnlock()
if endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
+13 -20
View File
@@ -12,30 +12,23 @@ import (
// HTTPClientConfig provides configuration for creating HTTP clients
type HTTPClientConfig struct {
// Timeout for the entire request
Timeout time.Duration
// MaxRedirects allowed (0 means follow Go's default of 10)
MaxRedirects int
// UseCookieJar enables cookie jar for the client
UseCookieJar bool
// Connection settings
IdleConnTimeout time.Duration
MaxIdleConns int
ReadBufferSize int
DialTimeout time.Duration
KeepAlive time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
// Connection pool settings
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Buffer settings
WriteBufferSize int
ReadBufferSize int
// Feature flags
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
MaxRedirects int
MaxIdleConnsPerHost int
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
}
// DefaultHTTPClientConfig returns the default configuration for general use
@@ -245,7 +238,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
// Add cookie jar if requested
if config.UseCookieJar {
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
client.Jar = jar
}
+210
View File
@@ -0,0 +1,210 @@
package traefikoidc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
config := OIDCProviderHTTPClientConfig()
// Verify OIDC-specific settings
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
}
// TestCreateDefaultClientUnit tests CreateDefaultClient function
func TestCreateDefaultClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateDefaultClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
}
// TestCreateTokenClientUnit tests CreateTokenClient function
func TestCreateTokenClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateTokenClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.NotNil(t, client.Jar, "token client should have cookie jar")
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
}
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
config := HTTPClientConfig{
Timeout: 5 * time.Second,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 5,
UseCookieJar: true,
}
client := CreateHTTPClientWithConfig(config)
require.NotNil(t, client)
assert.Equal(t, 5*time.Second, client.Timeout)
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
}
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
factory := NewHTTPClientFactory()
t.Run("zero values get defaults", func(t *testing.T) {
config := HTTPClientConfig{
// All zero values
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Verify defaults were applied
assert.Equal(t, 30*time.Second, client.Timeout)
})
t.Run("custom values preserved", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: 15 * time.Second,
MaxIdleConns: 50,
MaxRedirects: 3,
UseCookieJar: true,
ForceHTTP2: true,
DisableKeepAlives: true,
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
assert.Equal(t, 15*time.Second, client.Timeout)
assert.NotNil(t, client.Jar)
})
t.Run("invalid timeout gets default", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: -1 * time.Second, // Invalid
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Should get default due to validation failure
assert.Equal(t, 30*time.Second, client.Timeout)
})
}
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
factory := NewHTTPClientFactory()
tests := []struct {
name string
errorMsg string
config HTTPClientConfig
wantError bool
}{
{
name: "valid config",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 50,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 20,
},
wantError: false,
},
{
name: "negative MaxIdleConns",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: -1,
},
wantError: true,
errorMsg: "MaxIdleConns cannot be negative",
},
{
name: "MaxIdleConns too high",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 1500,
},
wantError: true,
errorMsg: "MaxIdleConns too high",
},
{
name: "negative MaxIdleConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: -1,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost cannot be negative",
},
{
name: "timeout too high",
config: HTTPClientConfig{
Timeout: 10 * time.Minute,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout too high",
},
{
name: "negative timeout",
config: HTTPClientConfig{
Timeout: -1 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout must be positive",
},
{
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: 50,
MaxConnsPerHost: 10,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := factory.ValidateHTTPClientConfig(&tt.config)
if tt.wantError {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+48 -16
View File
@@ -12,19 +12,19 @@ import (
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
type SharedTransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
transports map[string]*sharedTransport
cancel context.CancelFunc
clientCount int32 // SECURITY FIX: Track total HTTP clients
maxClients int32 // SECURITY FIX: Limit total clients to 5
maxConns int
mu sync.RWMutex
clientCount int32
maxClients int32
}
type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
lastUsed time.Time
}
var (
@@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
}
// cleanupIdleTransports periodically cleans up unused transports
// Uses two-phase cleanup to minimize lock contention:
// 1. Find candidates while holding read lock
// 2. Remove and close transports with minimal lock duration
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
p.mu.Lock()
now := time.Now()
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(p.transports, transportKey)
// SECURITY FIX: Decrement client count when removing transport
atomic.AddInt32(&p.clientCount, -1)
}
p.performCleanup()
}
}
}
// performCleanup does the actual cleanup with optimized locking
func (p *SharedTransportPool) performCleanup() {
now := time.Now()
// Phase 1: Find candidates while holding read lock (fast)
p.mu.RLock()
candidates := make([]string, 0)
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
candidates = append(candidates, transportKey)
}
}
p.mu.RUnlock()
if len(candidates) == 0 {
return
}
// Phase 2: Remove and close each candidate individually
// This minimizes lock contention and allows concurrent access
for _, key := range candidates {
p.mu.Lock()
shared, exists := p.transports[key]
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
// Remove from map first (releases memory)
delete(p.transports, key)
atomic.AddInt32(&p.clientCount, -1)
p.mu.Unlock()
// Close idle connections outside the lock (can be slow)
if shared.transport != nil {
shared.transport.CloseIdleConnections()
}
} else {
p.mu.Unlock()
}
}
+691
View File
@@ -0,0 +1,691 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
t.Run("create new transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
assert.Len(t, pool.transports, 1)
})
t.Run("reuse existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config)
transport2 := pool.GetOrCreateTransport(config)
assert.Equal(t, transport1, transport2, "should reuse same transport")
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
// Check ref count
pool.mu.RLock()
key := pool.configKey(config)
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
})
t.Run("client limit enforcement", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 5, // Already at max
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
assert.Nil(t, transport, "should return nil when at client limit")
})
t.Run("client limit with existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create first transport
config1 := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config1)
require.NotNil(t, transport1)
// Set client count to max
atomic.StoreInt32(&pool.clientCount, 5)
// Try to create with different config
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15 // Different config
transport2 := pool.GetOrCreateTransport(config2)
// Should return existing transport since at limit
assert.NotNil(t, transport2)
assert.Equal(t, transport1, transport2)
})
t.Run("updates last used time", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
firstTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
// Get again
transport2 := pool.GetOrCreateTransport(config)
require.NotNil(t, transport2)
pool.mu.RLock()
secondTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
})
}
// TestSharedTransportPoolReleaseTransport tests transport release
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
t.Run("decrement ref count", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Get again to increase ref count
pool.GetOrCreateTransport(config)
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 2, refCount)
// Release
pool.ReleaseTransport(transport)
pool.mu.RLock()
newRefCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, newRefCount)
})
t.Run("ref count reaches zero", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
pool.mu.RUnlock()
// Release to zero
pool.ReleaseTransport(transport)
pool.mu.RLock()
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 0, shared.refCount)
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
})
t.Run("release non-existent transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create a transport not in the pool
fakeTransport := &http.Transport{}
// Should not panic
assert.NotPanics(t, func() {
pool.ReleaseTransport(fakeTransport)
})
})
t.Run("release updates last used", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
time.Sleep(10 * time.Millisecond)
beforeRelease := time.Now()
pool.ReleaseTransport(transport)
pool.mu.RLock()
key := pool.configKey(config)
lastUsed := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
})
}
// TestSharedTransportPoolCleanup tests cleanup functionality
func TestSharedTransportPoolCleanup(t *testing.T) {
t.Run("cleanup all transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create multiple transports
config1 := DefaultHTTPClientConfig()
pool.GetOrCreateTransport(config1)
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15
pool.GetOrCreateTransport(config2)
assert.Greater(t, len(pool.transports), 0)
// Cleanup
pool.Cleanup()
assert.Len(t, pool.transports, 0, "all transports should be removed")
})
t.Run("cleanup cancels context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
pool.Cleanup()
select {
case <-pool.ctx.Done():
// Context was canceled
case <-time.After(100 * time.Millisecond):
t.Error("context should be canceled")
}
})
t.Run("cleanup with no transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
assert.NotPanics(t, func() {
pool.Cleanup()
})
})
t.Run("cleanup closes idle connections", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Cleanup should call CloseIdleConnections on each transport
pool.Cleanup()
// Verify transports map is cleared
assert.Empty(t, pool.transports)
})
}
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
if testing.Short() {
t.Skip("Skipping cleanup goroutine test in short mode")
}
t.Run("cleanup removes idle transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport and release it
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.ReleaseTransport(transport)
// Set lastUsed to old time
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Start cleanup in background (simulating what would happen)
// Note: We're testing the cleanup logic manually here
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
// Transport should be removed
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.False(t, exists, "old idle transport should be removed")
})
t.Run("cleanup preserves active transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport with refs
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Keep ref count > 0, but set old lastUsed
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup logic
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
}
}
pool.mu.Unlock()
// Transport should still exist (has ref count)
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.True(t, exists, "transport with references should be preserved")
})
t.Run("cleanup respects context cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Start cleanup goroutine
done := make(chan bool)
go func() {
pool.cleanupIdleTransports(ctx)
done <- true
}()
// Cancel context
cancel()
// Should exit quickly
select {
case <-done:
// Success
case <-time.After(2 * time.Second):
t.Error("cleanup goroutine should exit on context cancellation")
}
})
}
// TestCreatePooledHTTPClient tests pooled client creation
func TestCreatePooledHTTPClient(t *testing.T) {
t.Run("create client with default config", func(t *testing.T) {
config := DefaultHTTPClientConfig()
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.Transport)
assert.Equal(t, config.Timeout, client.Timeout)
})
t.Run("create multiple clients reuse transport", func(t *testing.T) {
// Reset global pool for clean test
globalTransportPoolOnce = sync.Once{}
globalTransportPool = nil
config := DefaultHTTPClientConfig()
client1 := CreatePooledHTTPClient(config)
client2 := CreatePooledHTTPClient(config)
require.NotNil(t, client1)
require.NotNil(t, client2)
// Should use same transport
assert.Equal(t, client1.Transport, client2.Transport)
})
t.Run("redirect policy is set", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 3
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.CheckRedirect)
// Test redirect limit
var redirects []*http.Request
for i := 0; i < 3; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after max redirects")
})
t.Run("default redirect limit", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 0 // Should default to 10
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
// Test default redirect limit (10)
var redirects []*http.Request
for i := 0; i < 10; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after 10 redirects")
})
}
// TestGetGlobalTransportPool tests singleton pattern
func TestGetGlobalTransportPool(t *testing.T) {
t.Run("returns same instance", func(t *testing.T) {
pool1 := GetGlobalTransportPool()
pool2 := GetGlobalTransportPool()
assert.Equal(t, pool1, pool2, "should return same singleton instance")
})
t.Run("pool is initialized", func(t *testing.T) {
pool := GetGlobalTransportPool()
require.NotNil(t, pool)
assert.NotNil(t, pool.transports)
assert.Equal(t, 20, pool.maxConns)
assert.Equal(t, int32(5), pool.maxClients)
assert.NotNil(t, pool.ctx)
assert.NotNil(t, pool.cancel)
})
}
// TestSharedTransportPoolConcurrency tests thread safety
func TestSharedTransportPoolConcurrency(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrency test in short mode")
}
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10, // Allow more for concurrency test
}
config := DefaultHTTPClientConfig()
const numGoroutines = 20
var wg sync.WaitGroup
transports := make([]*http.Transport, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
transports[idx] = pool.GetOrCreateTransport(config)
}(i)
}
wg.Wait()
// All should get same transport
firstTransport := transports[0]
for i := 1; i < numGoroutines; i++ {
if transports[i] != nil {
assert.Equal(t, firstTransport, transports[i])
}
}
})
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
// Increase ref count
for i := 0; i < 20; i++ {
pool.GetOrCreateTransport(config)
}
const numReleases = 20
var wg sync.WaitGroup
for i := 0; i < numReleases; i++ {
wg.Add(1)
go func() {
defer wg.Done()
pool.ReleaseTransport(transport)
}()
}
wg.Wait()
// Should not panic and ref count should be decremented
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
})
}
// TestSharedTransportPoolEdgeCases tests edge cases
func TestSharedTransportPoolEdgeCases(t *testing.T) {
t.Run("config key generation", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config1.MaxIdleConnsPerHost = 5
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 10
config2.MaxIdleConnsPerHost = 5
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.Equal(t, key1, key2, "same config should produce same key")
})
t.Run("different configs produce different keys", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 20
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
})
t.Run("client count decrements on cleanup", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
initialCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(1), initialCount)
// Release and mark as old
pool.ReleaseTransport(transport)
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
finalCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
})
}
+52 -47
View File
@@ -15,20 +15,21 @@ import (
// XSS, path traversal, and other injection attacks. It validates and sanitizes
// various input types used in OIDC authentication flows.
type InputValidator struct {
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
}
// ValidationResult encapsulates the outcome of input validation.
@@ -46,13 +47,14 @@ type ValidationResult struct {
// It specifies maximum lengths for various input types and controls whether
// strict validation mode is enabled.
type InputValidationConfig struct {
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
}
// DefaultInputValidationConfig returns a secure default configuration
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
}
}
// Check for private IP ranges (RFC 1918)
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
if !iv.allowPrivateIPAddresses {
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
+9 -9
View File
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
}
t.Run("Valid token validation", func(t *testing.T) {
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
result := validator.ValidateToken(validToken)
if !result.IsValid {
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
tests := []struct {
name string
token string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidJWTToken",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
expectValid: true,
description: "Valid JWT token should pass validation",
},
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
},
{
name: "MaliciousJWTWithExtraData",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
expectValid: false,
description: "JWT with extra malicious data should fail validation",
},
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
tests := []struct {
name string
email string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidEmail",
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
tests := []struct {
name string
url string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHTTPSURL",
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
name string
claimName string
claimValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidStringClaim",
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
name string
headerName string
headerValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHeader",
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
tests := []struct {
name string
username string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidUsername",
+4 -4
View File
@@ -726,20 +726,20 @@ type MockConfig struct {
}
type MockSession struct {
id string
userID string
created time.Time
lastUsed time.Time
data map[string]interface{}
id string
userID string
}
type TestResult struct {
UserID int
StartTime time.Time
EndTime time.Time
Error error
UserID int
Duration time.Duration
Success bool
Error error
}
// ============================================================================
+79
View File
@@ -0,0 +1,79 @@
package backends
import "time"
// BackendType represents the type of cache backend
type BackendType string
const (
BackendTypeMemory BackendType = "memory"
BackendTypeRedis BackendType = "redis"
BackendTypeHybrid BackendType = "hybrid"
// Aliases for backward compatibility
TypeMemory BackendType = "memory"
TypeRedis BackendType = "redis"
TypeHybrid BackendType = "hybrid"
)
// Config provides common configuration for cache backends
type Config struct {
L2Config *Config
L1Config *Config
RedisPrefix string
Type BackendType
RedisAddr string
RedisPassword string
PoolSize int
RedisDB int
CleanupInterval time.Duration
MaxMemoryBytes int64
MaxSize int
HealthCheckInterval time.Duration
AsyncWrites bool
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
}
// DefaultConfig returns a default configuration for in-memory caching
func DefaultConfig() *Config {
return &Config{
Type: BackendTypeMemory,
MaxSize: 1000,
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
CleanupInterval: 5 * time.Minute,
EnableMetrics: true,
}
}
// DefaultRedisConfig returns a default configuration for Redis caching
func DefaultRedisConfig(addr string) *Config {
return &Config{
Type: BackendTypeRedis,
RedisAddr: addr,
RedisDB: 0,
RedisPrefix: "traefikoidc:",
PoolSize: 10,
EnableCircuitBreaker: true,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
EnableMetrics: true,
}
}
// DefaultHybridConfig returns a default configuration for hybrid caching
func DefaultHybridConfig(redisAddr string) *Config {
return &Config{
Type: BackendTypeHybrid,
L1Config: &Config{
Type: BackendTypeMemory,
MaxSize: 500,
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
CleanupInterval: 1 * time.Minute,
},
L2Config: DefaultRedisConfig(redisAddr),
AsyncWrites: true,
EnableMetrics: true,
}
}
+59
View File
@@ -0,0 +1,59 @@
//go:build !yaegi
package backends
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDefaultHybridConfig verifies the default hybrid configuration
func TestDefaultHybridConfig(t *testing.T) {
redisAddr := "localhost:6379"
config := DefaultHybridConfig(redisAddr)
require.NotNil(t, config)
// Verify top-level config
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.True(t, config.AsyncWrites)
assert.True(t, config.EnableMetrics)
// Verify L1 (memory) config
require.NotNil(t, config.L1Config)
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
assert.Equal(t, 500, config.L1Config.MaxSize)
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
// Verify L2 (Redis) config exists
require.NotNil(t, config.L2Config)
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
}
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
tests := []struct {
name string
redisAddr string
}{
{"localhost", "localhost:6379"},
{"remote host", "redis.example.com:6379"},
{"IP address", "192.168.1.100:6379"},
{"custom port", "localhost:6380"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := DefaultHybridConfig(tt.redisAddr)
require.NotNil(t, config)
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.NotNil(t, config.L1Config)
assert.NotNil(t, config.L2Config)
})
}
}
+38
View File
@@ -0,0 +1,38 @@
package backends
import "errors"
var (
// ErrBackendClosed is returned when operating on a closed backend
ErrBackendClosed = errors.New("cache backend is closed")
// ErrKeyNotFound is returned when a key doesn't exist
ErrKeyNotFound = errors.New("key not found")
// ErrCacheMiss indicates the requested key was not found in the cache
ErrCacheMiss = errors.New("cache miss")
// ErrBackendUnavailable indicates the cache backend is not available
ErrBackendUnavailable = errors.New("cache backend unavailable")
// ErrInvalidValue indicates the cached value is invalid or corrupted
ErrInvalidValue = errors.New("invalid cached value")
// ErrInvalidTTL is returned when TTL is invalid
ErrInvalidTTL = errors.New("invalid TTL")
// ErrConnectionFailed is returned when connection fails
ErrConnectionFailed = errors.New("connection failed")
// ErrCircuitOpen is returned when circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTimeout is returned when operation times out
ErrTimeout = errors.New("operation timeout")
// ErrSerializationFailed is returned when serialization fails
ErrSerializationFailed = errors.New("serialization failed")
// ErrDeserializationFailed is returned when deserialization fails
ErrDeserializationFailed = errors.New("deserialization failed")
)
+685
View File
@@ -0,0 +1,685 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
// It provides automatic failover, async writes for non-critical data, and optimized read paths
type HybridBackend struct {
lastL2Error atomic.Value
secondary CacheBackend
primary CacheBackend
logger Logger
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
errors atomic.Int64
l2Writes atomic.Int64
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
fallbackMode atomic.Bool
}
// asyncWriteItem represents an async write operation
type asyncWriteItem struct {
ctx context.Context
key string
value []byte
ttl time.Duration
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// defaultLogger provides a basic logger implementation
type defaultLogger struct {
*log.Logger
}
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
l.Printf("[DEBUG] "+format, args...)
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
l.Printf("[INFO] "+format, args...)
}
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
l.Printf("[WARN] "+format, args...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
l.Printf("[ERROR] "+format, args...)
}
// HybridConfig provides configuration for the hybrid backend
type HybridConfig struct {
Primary CacheBackend
Secondary CacheBackend
Logger Logger
SyncWriteCacheTypes map[string]bool
AsyncBufferSize int
}
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.Primary == nil {
return nil, fmt.Errorf("primary (L1) backend is required")
}
if config.Secondary == nil {
return nil, fmt.Errorf("secondary (L2) backend is required")
}
if config.Logger == nil {
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
}
if config.AsyncBufferSize <= 0 {
config.AsyncBufferSize = 1000
}
// Default critical cache types that require synchronous writes
if config.SyncWriteCacheTypes == nil {
config.SyncWriteCacheTypes = map[string]bool{
"blacklist": true, // Token blacklist must be immediately consistent
"token": true, // Token validation is critical
}
}
ctx, cancel := context.WithCancel(context.Background())
h := &HybridBackend{
primary: config.Primary,
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
}
// Start async write worker
h.wg.Add(1)
go h.asyncWriteWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
return h, nil
}
// Set stores a value in both L1 and L2 caches
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Always write to L1 first (synchronous)
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L1 cache: %v", err)
// Continue to try L2 even if L1 fails
} else {
h.l1Writes.Add(1)
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
return nil // Don't fail the operation if L2 is down
}
// Determine if this should be a sync or async write based on cache type
cacheType := h.extractCacheType(key)
requiresSync := h.syncWriteCacheTypes[cacheType]
if requiresSync {
// Synchronous write for critical cache types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
h.recordL2Error()
// Don't fail the operation - L1 write succeeded
return nil
}
h.l2Writes.Add(1)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
} else {
// Asynchronous write for non-critical cache types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
h.logger.Debugf("Queued async write to L2 for key: %s", key)
default:
// Buffer is full, log and continue
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
h.errors.Add(1)
}
}
return nil
}
// Get retrieves a value from cache, checking L1 first, then L2
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
// Try L1 first
value, ttl, exists, err := h.primary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L1 get error for key %s: %v", key, err)
}
if exists {
h.l1Hits.Add(1)
return value, ttl, true, nil
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.misses.Add(1)
return nil, 0, false, nil
}
// Try L2
value, ttl, exists, err = h.secondary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L2 get error for key %s: %v", key, err)
h.recordL2Error()
h.misses.Add(1)
return nil, 0, false, nil // Don't propagate L2 errors
}
if !exists {
h.misses.Add(1)
return nil, 0, false, nil
}
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
return value, ttl, true, nil
}
// Delete removes a key from both L1 and L2 caches
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
var deleted bool
// Delete from L1
if d, err := h.primary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
} else if d {
deleted = true
}
// Delete from L2 if not in fallback mode
if !h.fallbackMode.Load() {
if d, err := h.secondary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
h.recordL2Error()
} else if d {
deleted = true
}
}
return deleted, nil
}
// Exists checks if a key exists in either cache
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
// Check L1 first
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
return true, nil
}
// Check L2 if not in fallback mode
if !h.fallbackMode.Load() {
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
return true, nil
}
}
return false, nil
}
// Clear removes all keys from both caches
func (h *HybridBackend) Clear(ctx context.Context) error {
var lastErr error
// Clear L1
if err := h.primary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L1 cache: %v", err)
lastErr = err
}
// Clear L2 if not in fallback mode
if !h.fallbackMode.Load() {
if err := h.secondary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L2 cache: %v", err)
h.recordL2Error()
lastErr = err
}
}
return lastErr
}
// GetStats returns statistics for the hybrid cache
func (h *HybridBackend) GetStats() map[string]interface{} {
l1Hits := h.l1Hits.Load()
l2Hits := h.l2Hits.Load()
misses := h.misses.Load()
total := l1Hits + l2Hits + misses
stats := map[string]interface{}{
"type": TypeHybrid,
"l1_hits": l1Hits,
"l2_hits": l2Hits,
"misses": misses,
"total": total,
"l1_writes": h.l1Writes.Load(),
"l2_writes": h.l2Writes.Load(),
"errors": h.errors.Load(),
"fallback_mode": h.fallbackMode.Load(),
}
if total > 0 {
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
}
// Add sub-backend stats
stats["l1_stats"] = h.primary.GetStats()
stats["l2_stats"] = h.secondary.GetStats()
// Add last L2 error time if available
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok {
stats["last_l2_error"] = t.Format(time.RFC3339)
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
}
}
return stats
}
// Ping checks if both backends are healthy
func (h *HybridBackend) Ping(ctx context.Context) error {
// Check L1
if err := h.primary.Ping(ctx); err != nil {
return fmt.Errorf("L1 ping failed: %w", err)
}
// Check L2 (but don't fail if it's down)
if err := h.secondary.Ping(ctx); err != nil {
h.logger.Warnf("L2 ping failed: %v", err)
h.recordL2Error()
// Don't return error - we can operate with L1 only
} else {
// L2 is healthy, clear fallback mode if it was set
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend recovered, exiting fallback mode")
}
}
return nil
}
// Close shuts down the hybrid backend
func (h *HybridBackend) Close() error {
// Cancel context to stop workers
h.cancel()
// Close async write channel
close(h.asyncWriteBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Workers finished
case <-time.After(5 * time.Second):
h.logger.Warnf("Timeout waiting for workers to finish")
}
var lastErr error
// Close backends
if err := h.primary.Close(); err != nil {
h.logger.Errorf("Failed to close L1 backend: %v", err)
lastErr = err
}
if err := h.secondary.Close(); err != nil {
h.logger.Errorf("Failed to close L2 backend: %v", err)
lastErr = err
}
h.logger.Infof("HybridBackend closed")
return lastErr
}
// GetMany retrieves multiple values efficiently
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
results := make(map[string][]byte, len(keys))
missingKeys := make([]string, 0)
// Try L1 first for all keys
for _, key := range keys {
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
results[key] = value
h.l1Hits.Add(1)
} else {
missingKeys = append(missingKeys, key)
}
}
// If all found in L1 or in fallback mode, return
if len(missingKeys) == 0 || h.fallbackMode.Load() {
return results, nil
}
// Try L2 for missing keys using batch operation if available
if batcher, ok := h.secondary.(interface {
GetMany(context.Context, []string) (map[string][]byte, error)
}); ok {
l2Results, err := batcher.GetMany(ctx, missingKeys)
if err != nil {
h.logger.Debugf("L2 batch get error: %v", err)
h.recordL2Error()
} else {
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
}
}
} else {
// Fallback to individual gets
for _, key := range missingKeys {
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
}
}
}
// Count misses for keys not found anywhere
for _, key := range keys {
if _, found := results[key]; !found {
h.misses.Add(1)
}
}
return results, nil
}
// SetMany stores multiple key-value pairs efficiently
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if len(items) == 0 {
return nil
}
// Write to L1 first
for key, value := range items {
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
} else {
h.l1Writes.Add(1)
}
}
// Skip L2 if in fallback mode
if h.fallbackMode.Load() {
return nil
}
// Check if L2 supports batch operations
if batcher, ok := h.secondary.(interface {
SetMany(context.Context, map[string][]byte, time.Duration) error
}); ok {
if err := batcher.SetMany(ctx, items, ttl); err != nil {
h.logger.Warnf("Failed to batch write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(int64(len(items)))
}
} else {
// Fallback to individual sets
for key, value := range items {
cacheType := h.extractCacheType(key)
if h.syncWriteCacheTypes[cacheType] {
// Sync write for critical types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
}
} else {
// Async write for non-critical types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
// Queued
default:
h.logger.Warnf("Async buffer full for batch write")
}
}
}
}
return nil
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items with best effort
for len(h.asyncWriteBuffer) > 0 {
select {
case item := <-h.asyncWriteBuffer:
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.asyncWriteBuffer:
if !ok {
return
}
// Skip if in fallback mode
if h.fallbackMode.Load() {
continue
}
// Perform the write with a timeout
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.errors.Add(1)
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
}
cancel()
}
}
}
// healthMonitor periodically checks L2 health and manages fallback mode
func (h *HybridBackend) healthMonitor() {
defer h.wg.Done()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := h.secondary.Ping(ctx); err != nil {
if !h.fallbackMode.Load() {
h.fallbackMode.Store(true)
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
}
} else {
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend healthy, exiting fallback mode")
}
}
cancel()
}
}
}
// recordL2Error records the timestamp of an L2 error
func (h *HybridBackend) recordL2Error() {
h.lastL2Error.Store(time.Now())
// Check if we should enter fallback mode based on recent errors
if !h.fallbackMode.Load() {
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
h.fallbackMode.Store(true)
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
}
}
}
}
// extractCacheType attempts to determine the cache type from the key
func (h *HybridBackend) extractCacheType(key string) string {
// Simple heuristic based on key prefixes
// This should match the actual cache type strategy in the main application
if len(key) > 10 {
prefix := key[:10]
switch {
case contains(prefix, "blacklist"):
return "blacklist"
case contains(prefix, "token"):
return "token"
case contains(prefix, "metadata"):
return "metadata"
case contains(prefix, "jwk"):
return "jwk"
case contains(prefix, "session"):
return "session"
case contains(prefix, "introspect"):
return "introspection"
}
}
return "general"
}
// contains checks if a string contains a substring (case-insensitive)
func contains(s, substr string) bool {
if len(substr) > len(s) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if toLower(s[i+j]) != toLower(substr[j]) {
match = false
break
}
}
if match {
return true
}
}
return false
}
// toLower converts a byte to lowercase
func toLower(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + 32
}
return b
}
File diff suppressed because it is too large Load Diff
+102
View File
@@ -0,0 +1,102 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"time"
)
// CacheBackend defines the interface for all cache backend implementations
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
type CacheBackend interface {
// Set stores a value in the cache with the specified TTL
// Returns an error if the operation fails
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Get retrieves a value from the cache
// Returns: value, remaining TTL, exists flag, and error
// If the key doesn't exist, exists will be false
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
// Delete removes a key from the cache
// Returns true if the key was deleted, false if it didn't exist
Delete(ctx context.Context, key string) (bool, error)
// Exists checks if a key exists in the cache
Exists(ctx context.Context, key string) (bool, error)
// Clear removes all keys from the cache
Clear(ctx context.Context) error
// GetStats returns cache statistics
// Stats include: hits, misses, size, memory usage, etc.
GetStats() map[string]interface{}
// Close shuts down the cache backend and releases resources
Close() error
// Ping checks if the backend is healthy and responsive
Ping(ctx context.Context) error
}
// BackendStats represents statistics for a cache backend
type BackendStats struct {
StartTime time.Time
LastErrorTime time.Time
Type BackendType
LastError string
Deletes int64
Errors int64
Evictions int64
CurrentSize int64
MaxSize int64
MemoryUsage int64
AverageGetLatency time.Duration
AverageSetLatency time.Duration
Sets int64
Misses int64
Uptime time.Duration
Hits int64
}
// BackendCapabilities describes the capabilities of a cache backend
type BackendCapabilities struct {
// Distributed indicates if the backend is distributed across multiple instances
Distributed bool
// Persistent indicates if the backend persists data across restarts
Persistent bool
// Eviction indicates if the backend supports automatic eviction
Eviction bool
// TTL indicates if the backend supports TTL (time-to-live)
TTL bool
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
MaxKeySize int64
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
MaxValueSize int64
// MaxKeys is the maximum number of keys (0 = unlimited)
MaxKeys int64
// SupportsExpire indicates if the backend supports expiration
SupportsExpire bool
// SupportsMultiGet indicates if the backend supports batch get operations
SupportsMultiGet bool
// SupportsTransaction indicates if the backend supports transactions
SupportsTransaction bool
// SupportsCompression indicates if the backend supports compression
SupportsCompression bool
// RequiresSerialize indicates if values must be serialized
RequiresSerialize bool
// AtomicOperations indicates if the backend supports atomic operations
AtomicOperations bool
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
func TestCacheBackendContract(t *testing.T) {
// Test suite will be run against each backend type
t.Run("MemoryBackend", func(t *testing.T) {
backend := setupMemoryBackend(t)
runContractTests(t, backend)
})
t.Run("RedisBackend", func(t *testing.T) {
backend := setupRedisBackend(t)
runContractTests(t, backend)
})
t.Run("HybridBackend", func(t *testing.T) {
backend := setupHybridBackend(t)
runContractTests(t, backend)
})
}
// runContractTests executes all contract tests against a backend
func runContractTests(t *testing.T, backend CacheBackend) {
t.Helper()
ctx := context.Background()
t.Run("BasicSetGet", func(t *testing.T) {
testBasicSetGet(t, ctx, backend)
})
t.Run("GetNonExistent", func(t *testing.T) {
testGetNonExistent(t, ctx, backend)
})
t.Run("UpdateExisting", func(t *testing.T) {
testUpdateExisting(t, ctx, backend)
})
t.Run("Delete", func(t *testing.T) {
testDelete(t, ctx, backend)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
testDeleteNonExistent(t, ctx, backend)
})
t.Run("Exists", func(t *testing.T) {
testExists(t, ctx, backend)
})
t.Run("TTLExpiration", func(t *testing.T) {
testTTLExpiration(t, ctx, backend)
})
t.Run("Clear", func(t *testing.T) {
testClear(t, ctx, backend)
})
t.Run("Ping", func(t *testing.T) {
testPing(t, ctx, backend)
})
t.Run("Stats", func(t *testing.T) {
testStats(t, ctx, backend)
})
t.Run("ConcurrentAccess", func(t *testing.T) {
testConcurrentAccess(t, ctx, backend)
})
t.Run("LargeValues", func(t *testing.T) {
testLargeValues(t, ctx, backend)
})
t.Run("EmptyValues", func(t *testing.T) {
testEmptyValues(t, ctx, backend)
})
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
testSpecialCharactersInKeys(t, ctx, backend)
})
}
// testBasicSetGet verifies basic set and get operations
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "test-key-1"
value := []byte("test-value-1")
ttl := 1 * time.Minute
// Set value
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err, "Set should not return error")
// Get value
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error")
assert.True(t, exists, "Key should exist")
assert.Equal(t, value, retrieved, "Retrieved value should match")
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
}
// testGetNonExistent verifies behavior when getting non-existent keys
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-key"
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error for non-existent key")
assert.False(t, exists, "Key should not exist")
assert.Nil(t, retrieved, "Value should be nil")
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
}
// testUpdateExisting verifies updating an existing key
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
ttl := 1 * time.Minute
// Set initial value
err := backend.Set(ctx, key, value1, ttl)
require.NoError(t, err)
// Update value
err = backend.Set(ctx, key, value2, ttl)
require.NoError(t, err)
// Verify updated value
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved, "Value should be updated")
}
// testDelete verifies delete operation
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "delete-key"
value := []byte("delete-value")
// Set value
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Verify exists
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Delete
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted, "Delete should return true for existing key")
// Verify deleted
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after delete")
}
// testDeleteNonExistent verifies deleting non-existent keys
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-delete-key"
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.False(t, deleted, "Delete should return false for non-existent key")
}
// testExists verifies the Exists operation
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "exists-key"
value := []byte("exists-value")
// Check non-existent key
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist initially")
// Set value
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check existing key
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist after Set")
}
// testTTLExpiration verifies TTL expiration behavior
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
// Set with short TTL
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist immediately after Set")
// Wait for expiration
time.Sleep(200 * time.Millisecond)
// Verify expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after TTL expiration")
}
// testClear verifies Clear operation
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
// Set multiple keys
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Give async writes time to complete before clearing
// This prevents race conditions with async write workers
time.Sleep(50 * time.Millisecond)
// Clear all
err := backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after Clear")
}
}
// testPing verifies Ping operation
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
err := backend.Ping(ctx)
assert.NoError(t, err, "Ping should succeed on healthy backend")
}
// testStats verifies GetStats operation
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
stats := backend.GetStats()
assert.NotNil(t, stats, "Stats should not be nil")
// Stats should contain basic metrics
_, hasHits := stats["hits"]
_, hasMisses := stats["misses"]
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
}
// testConcurrentAccess verifies thread safety
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
var wg sync.WaitGroup
goroutines := 10
iterations := 20
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
}
}(i)
}
wg.Wait()
}
// testLargeValues verifies handling of large values
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "large-value-key"
value := GenerateLargeValue(1024 * 1024) // 1MB
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle large values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
}
// testEmptyValues verifies handling of empty values
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "empty-value-key"
value := []byte{}
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle empty values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Empty value should exist")
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
}
// testSpecialCharactersInKeys verifies handling of special characters in keys
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
specialKeys := []string{
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
"key|with|pipes",
}
for _, key := range specialKeys {
value := []byte(fmt.Sprintf("value-for-%s", key))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle special character in key: %s", key)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key with special characters should exist: %s", key)
assert.Equal(t, value, retrieved)
}
}
// Helper functions to setup different backend types
// These will be implemented in respective test files
func setupMemoryBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in memory_test.go
// For now, return nil to allow compilation
t.Skip("MemoryBackend implementation pending")
return nil
}
func setupRedisBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in redis_test.go
// For now, return nil to allow compilation
t.Skip("RedisBackend implementation pending")
return nil
}
func setupHybridBackend(t *testing.T) CacheBackend {
t.Helper()
primary := newMockBackend()
secondary := newMockBackend()
config := &HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 100,
Logger: NewTestLogger(t),
}
hybrid, err := NewHybridBackend(config)
require.NoError(t, err)
t.Cleanup(func() {
hybrid.Close()
})
return hybrid
}
+535
View File
@@ -0,0 +1,535 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// Default configuration values
const (
defaultShardCount = 256
defaultMaxSize = int64(10000)
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
defaultCleanupInterval = 5 * time.Minute
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
value interface{}
element interface{} // *list.Element, using interface{} to avoid import cycle
key string
accessCount int64
size int64
}
// isExpired checks if the item is expired
func (item *memoryCacheItem) isExpired() bool {
if item.expiresAt.IsZero() {
return false
}
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
// The sharded design reduces lock contention by partitioning keys across multiple shards,
// each with its own lock.
type MemoryCacheBackend struct {
shards []*cacheShard
startTime time.Time
lastErrorTime time.Time
cleanupDone chan struct{}
cleanupTicker *time.Ticker
lastError string
shardCount uint32
shardMask uint32
maxSize int64
maxMemory int64
cleanupInterval time.Duration
// Global stats (aggregated from shards)
hits atomic.Int64
misses atomic.Int64
sets atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
// Latency tracking
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
// State
closed atomic.Bool
mu sync.RWMutex // For global operations like stats and error tracking
}
// NewMemoryCacheBackend creates a new sharded memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = defaultMaxSize
}
if maxMemory <= 0 {
maxMemory = defaultMaxMemory
}
if cleanupInterval <= 0 {
cleanupInterval = defaultCleanupInterval
}
shardCount := uint32(defaultShardCount)
// For very small caches, reduce shard count to maintain sensible per-shard limits
// Ensure each shard can hold at least 2 items for proper LRU behavior
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
shardCount /= 2
}
if shardCount < 1 {
shardCount = 1
}
// Per-shard limits are soft hints; global limits are enforced
// Give shards 2x the average to allow for uneven distribution
shardMaxSize := (maxSize * 2) / int64(shardCount)
if shardMaxSize < 4 {
shardMaxSize = 4
}
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
if shardMaxMemory < 4096 {
shardMaxMemory = 4096 // Minimum 4KB per shard
}
m := &MemoryCacheBackend{
shards: make([]*cacheShard, shardCount),
shardCount: shardCount,
shardMask: shardCount - 1, // For fast modulo with power-of-2
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
cleanupDone: make(chan struct{}),
}
// Initialize shards
for i := uint32(0); i < shardCount; i++ {
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
}
// Start cleanup goroutine
m.cleanupTicker = time.NewTicker(cleanupInterval)
go m.cleanupLoop()
return m
}
// getShard returns the shard for a given key
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
hash := fnv32(key)
return m.shards[hash&m.shardMask]
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
select {
case <-m.cleanupTicker.C:
m.cleanupExpired()
case <-m.cleanupDone:
return
}
}
}
// cleanupExpired removes all expired items from all shards
func (m *MemoryCacheBackend) cleanupExpired() {
if m.closed.Load() {
return
}
totalRemoved := 0
for _, shard := range m.shards {
totalRemoved += shard.cleanup()
}
if totalRemoved > 0 {
m.evictions.Add(int64(totalRemoved))
}
}
// Get retrieves a value from the cache
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalGetTime.Add(duration)
m.getCount.Add(1)
}()
shard := m.getShard(key)
value, exists, expired := shard.get(key)
if expired {
// Clean up expired item
shard.delete(key)
m.misses.Add(1)
return nil, ErrCacheMiss
}
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
m.hits.Add(1)
return value, nil
}
// Set stores a value in the cache with optional TTL
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalSetTime.Add(duration)
m.setCount.Add(1)
}()
// Calculate item size
itemSize := int64(len(key)) + estimateValueSize(value)
// Enforce global limits before adding new item
m.enforceGlobalLimits(itemSize)
var expiresAt time.Time
if ttl > 0 {
expiresAt = time.Now().Add(ttl)
}
shard := m.getShard(key)
shard.set(key, value, expiresAt, itemSize)
m.sets.Add(1)
return nil
}
// enforceGlobalLimits ensures global size and memory limits are respected
// by evicting from shards when necessary
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
// Check and enforce size limit
for {
totalSize, totalMemory := m.getGlobalStats()
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
if !needsSizeEviction && !needsMemoryEviction {
break
}
// Find the shard with the most items and evict from it
evicted := m.evictFromLargestShard()
if !evicted {
break // No more items to evict
}
m.evictions.Add(1)
}
}
// getGlobalStats returns the total size and memory usage across all shards
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
return
}
// evictFromLargestShard evicts the globally oldest item across all shards
// This provides true LRU behavior even with sharding
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
var oldestShard *cacheShard
var oldestTime time.Time
for _, shard := range m.shards {
accessTime := shard.getOldestAccessTime()
// Skip empty shards
if accessTime.IsZero() {
continue
}
// Find the shard with the oldest (earliest) access time
if oldestShard == nil || accessTime.Before(oldestTime) {
oldestTime = accessTime
oldestShard = shard
}
}
if oldestShard == nil {
return false
}
return oldestShard.evictOne()
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
shard := m.getShard(key)
if shard.delete(key) {
m.deletes.Add(1)
}
return nil
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
shard := m.getShard(key)
return shard.exists(key), nil
}
// Clear removes all items from the cache
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
for _, shard := range m.shards {
shard.clear()
}
return nil
}
// Keys returns all keys matching the pattern (use "*" for all keys)
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
var allKeys []string
for _, shard := range m.shards {
keys := shard.keys(pattern)
allKeys = append(allKeys, keys...)
}
return allKeys, nil
}
// Size returns the total number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
var total int64
for _, shard := range m.shards {
size, _ := shard.stats()
total += size
}
return total, nil
}
// TTL returns the remaining time-to-live for a key
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
shard := m.getShard(key)
ttl, exists := shard.ttl(key)
if !exists {
return 0, ErrCacheMiss
}
return ttl, nil
}
// Expire updates the TTL for an existing key
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
shard := m.getShard(key)
if !shard.expire(key, ttl) {
return ErrCacheMiss
}
return nil
}
// GetStats returns statistics about the cache backend
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
// Aggregate stats from all shards
var totalSize, totalMemory int64
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
m.mu.RUnlock()
avgGetLatency := time.Duration(0)
if getCount := m.getCount.Load(); getCount > 0 {
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
}
avgSetLatency := time.Duration(0)
if setCount := m.setCount.Load(); setCount > 0 {
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
}
return &BackendStats{
Type: TypeMemory,
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Sets: m.sets.Load(),
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: totalSize,
MaxSize: m.maxSize,
MemoryUsage: totalMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
LastErrorTime: lastErrorTime,
Uptime: time.Since(m.startTime),
StartTime: m.startTime,
}, nil
}
// Ping checks if the backend is healthy
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
return nil
}
// Close closes the backend and releases resources
func (m *MemoryCacheBackend) Close() error {
if m.closed.Swap(true) {
return nil // Already closed
}
m.cleanupTicker.Stop()
close(m.cleanupDone)
// Clear all shards
for _, shard := range m.shards {
shard.clear()
}
return nil
}
// IsHealthy returns true if the backend is healthy
func (m *MemoryCacheBackend) IsHealthy() bool {
return !m.closed.Load()
}
// Type returns the backend type
func (m *MemoryCacheBackend) Type() BackendType {
return TypeMemory
}
// Capabilities returns the backend capabilities
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
return &BackendCapabilities{
Distributed: false,
Persistent: false,
Eviction: true,
TTL: true,
MaxKeySize: 1024, // 1KB
MaxValueSize: 10485760, // 10MB
MaxKeys: m.maxSize,
SupportsExpire: true,
SupportsMultiGet: true,
SupportsTransaction: false,
SupportsCompression: false,
RequiresSerialize: false,
}
}
// GetShardCount returns the number of shards (for testing/monitoring)
func (m *MemoryCacheBackend) GetShardCount() uint32 {
return m.shardCount
}
// GetShardStats returns per-shard statistics (for monitoring)
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
stats := make([]map[string]int64, m.shardCount)
for i, shard := range m.shards {
size, memory := shard.stats()
stats[i] = map[string]int64{
"size": size,
"memory": memory,
}
}
return stats
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
switch v := value.(type) {
case string:
return int64(len(v))
case []byte:
return int64(len(v))
case int, int32, int64, uint, uint32, uint64:
return 8
case float32, float64:
return 8
case bool:
return 1
default:
// For complex types, use a default estimate
return 256
}
}
// matchPattern checks if a key matches a pattern (simplified glob matching)
func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching
if len(pattern) > 0 && pattern[0] == '*' {
suffix := pattern[1:]
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
}
return key == pattern
}
+182
View File
@@ -0,0 +1,182 @@
package backends
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
)
// setupBenchmarkRedis creates a miniredis instance for benchmarking
func setupBenchmarkRedis(b *testing.B) string {
b.Helper()
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() {
mr.Close()
})
return mr.Addr()
}
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
// Perform various operations
_, _ = conn.Do("SET", "bench-key", "bench-value")
_, _ = conn.Do("GET", "bench-key")
_, _ = conn.Do("EXISTS", "bench-key")
_, _ = conn.Do("DEL", "bench-key")
pool.Put(conn)
}
}
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
func BenchmarkRedisBackend_SetGet(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("benchmark test data with some content")
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Set operation
err := backend.Set(ctx, "bench-key", testData, 0)
if err != nil {
b.Fatal(err)
}
// Get operation
_, _, _, err = backend.Get(ctx, "bench-key")
if err != nil {
b.Fatal(err)
}
}
}
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("concurrent benchmark data")
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = backend.Set(ctx, "concurrent-key", testData, 0)
_, _, _, _ = backend.Get(ctx, "concurrent-key")
}
})
}
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
defer pool.Put(conn)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// This tests the pooling of RESPReader/RESPWriter
_, _ = conn.Do("PING")
}
}
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
func BenchmarkConnectionPool_GetPut(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
pool.Put(conn)
}
}
+290
View File
@@ -0,0 +1,290 @@
package backends
import (
"container/list"
"sync"
"time"
)
// cacheShard represents a single shard of the sharded cache
// Each shard has its own lock for reduced contention
type cacheShard struct {
items map[string]*memoryCacheItem
lruList *list.List
mu sync.RWMutex
maxSize int64
maxMemory int64
size int64
memoryUsed int64
}
// newCacheShard creates a new cache shard
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
return &cacheShard{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
}
}
// get retrieves a value from this shard
// Returns: value, exists, expired
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return nil, false, false
}
if item.isExpired() {
return nil, true, true // exists but expired
}
// Update access time and LRU position under write lock
s.mu.Lock()
// Re-check item exists (could have been deleted)
item, exists = s.items[key]
if exists && !item.isExpired() {
item.accessedAt = time.Now()
item.accessCount++
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.MoveToFront(elem)
}
}
s.mu.Unlock()
if !exists || item.isExpired() {
return nil, false, false
}
return item.value, true, false
}
// set stores a value in this shard
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
s.mu.Lock()
defer s.mu.Unlock()
// Check if we need to evict items
if s.maxSize > 0 && s.size >= s.maxSize {
s.evictLRULocked()
}
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
s.evictLRULocked()
}
// Remove old item if exists
if oldItem, exists := s.items[key]; exists {
s.memoryUsed -= oldItem.size
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
s.size--
}
now := time.Now()
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: size,
}
item.element = s.lruList.PushFront(item)
s.items[key] = item
s.size++
s.memoryUsed += size
}
// delete removes a key from this shard
// Returns true if the key was deleted
func (s *cacheShard) delete(key string) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists {
return false
}
s.deleteItemLocked(item)
return true
}
// exists checks if a key exists (and is not expired)
func (s *cacheShard) exists(key string) bool {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return false
}
return !item.isExpired()
}
// ttl returns the remaining TTL for a key
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists || item.isExpired() {
return 0, false
}
if item.expiresAt.IsZero() {
return 0, true // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, false
}
return remaining, true
}
// expire updates the TTL for an existing key
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists || item.isExpired() {
return false
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return true
}
// keys returns all non-expired keys matching the pattern
func (s *cacheShard) keys(pattern string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
var keys []string
for key, item := range s.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys
}
// clear removes all items from this shard
func (s *cacheShard) clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.items = make(map[string]*memoryCacheItem)
s.lruList.Init()
s.size = 0
s.memoryUsed = 0
}
// cleanup removes expired items
// Returns the number of items removed
func (s *cacheShard) cleanup() int {
s.mu.Lock()
defer s.mu.Unlock()
var toRemove []*memoryCacheItem
for _, item := range s.items {
if item.isExpired() {
toRemove = append(toRemove, item)
}
}
for _, item := range toRemove {
s.deleteItemLocked(item)
}
return len(toRemove)
}
// stats returns statistics for this shard
func (s *cacheShard) stats() (size, memory int64) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.size, s.memoryUsed
}
// deleteItemLocked removes an item (must be called with lock held)
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
delete(s.items, item.key)
s.size--
s.memoryUsed -= item.size
}
// evictLRULocked evicts the least recently used item (must be called with lock held)
func (s *cacheShard) evictLRULocked() bool {
if s.lruList.Len() == 0 {
return false
}
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
s.deleteItemLocked(item)
return true
}
return false
}
// evictOne evicts one item from this shard (for global limit enforcement)
func (s *cacheShard) evictOne() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.evictLRULocked()
}
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
// Returns zero time if shard is empty
func (s *cacheShard) getOldestAccessTime() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
if s.lruList.Len() == 0 {
return time.Time{}
}
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
return item.accessedAt
}
return time.Time{}
}
// fnv32 computes FNV-1a hash of a string
// This is a fast, well-distributed hash function
func fnv32(key string) uint32 {
const (
offset32 = uint32(2166136261)
prime32 = uint32(16777619)
)
hash := offset32
for i := 0; i < len(key); i++ {
hash ^= uint32(key[i])
hash *= prime32
}
return hash
}
+283
View File
@@ -0,0 +1,283 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
func TestShardedCache_ShardDistribution(t *testing.T) {
t.Parallel()
// Create a cache with large enough size to have multiple shards
config := DefaultConfig()
config.MaxSize = 10000
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add many items to see distribution
numItems := 1000
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("dist-key-%d", i)
value := []byte(fmt.Sprintf("dist-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Check that items are distributed across multiple shards
shardStats := backend.MemoryCacheBackend.GetShardStats()
nonEmptyShards := 0
for _, stat := range shardStats {
if stat["size"] > 0 {
nonEmptyShards++
}
}
// With good hash distribution, we should have items in multiple shards
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
}
// TestShardedCache_ShardCount tests that shard count adapts to cache size
func TestShardedCache_ShardCount(t *testing.T) {
t.Parallel()
tests := []struct {
maxSize int
expectLowShards bool
}{
{5, true}, // Very small cache should have fewer shards
{100, true}, // Small cache should have fewer shards
{10000, false}, // Large cache should have default shards
}
for _, tt := range tests {
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
config := DefaultConfig()
config.MaxSize = tt.maxSize
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
shardCount := backend.MemoryCacheBackend.GetShardCount()
if tt.expectLowShards {
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
} else {
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
}
})
}
}
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "concurrent-same-key"
initialValue := []byte("initial-value")
err = backend.Set(ctx, key, initialValue, time.Minute)
require.NoError(t, err)
var wg sync.WaitGroup
goroutines := 50
iterations := 100
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Mix of reads and writes
if j%3 == 0 {
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
err := backend.Set(ctx, key, newValue, time.Minute)
assert.NoError(t, err)
} else {
_, _, _, err := backend.Get(ctx, key)
assert.NoError(t, err)
}
}
}(i)
}
wg.Wait()
// Key should still exist
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
t.Parallel()
// Create a small cache to force eviction
config := DefaultConfig()
config.MaxSize = 10
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
// Small delay to ensure different access times
time.Sleep(time.Millisecond)
}
// Access some items to make them recently used
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
_, _, _, err := backend.Get(ctx, key)
require.NoError(t, err)
}
// Add more items to trigger eviction
for i := 10; i < 15; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Recently accessed items (5-9) should still exist
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item %d should exist", i)
}
// Check eviction stats
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
func TestShardedCache_StatsAggregation(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 10000
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items to multiple shards
numItems := 100
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("stats-key-%d", i)
value := []byte(fmt.Sprintf("stats-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Read some items
for i := 0; i < numItems/2; i++ {
key := fmt.Sprintf("stats-key-%d", i)
backend.Get(ctx, key)
}
// Read non-existent items
for i := 0; i < 10; i++ {
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
}
stats := backend.GetStats()
// Verify stats
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
// Verify hit rate
hitRate := stats["hit_rate"].(float64)
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
}
// BenchmarkShardedCache_Parallel benchmarks parallel access
func BenchmarkShardedCache_Parallel(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10000; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("bench-key-%d", i%10000)
backend.Get(ctx, key)
i++
}
})
}
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
func BenchmarkShardedCache_MixedOps(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("mixed-key-%d", i%1000)
if i%3 == 0 {
value := []byte(fmt.Sprintf("mixed-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
} else {
backend.Get(ctx, key)
}
i++
}
})
}
+783
View File
@@ -0,0 +1,783 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMemoryBackend_BasicOperations tests basic CRUD operations
func TestMemoryBackend_BasicOperations(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "test-key"
value := []byte("test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
assert.LessOrEqual(t, remainingTTL, ttl)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "delete-key"
value := []byte("delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
deleted, err := backend.Delete(ctx, "non-existent-delete")
require.NoError(t, err)
assert.False(t, deleted)
})
t.Run("Exists", func(t *testing.T) {
key := "exists-key"
value := []byte("exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
// Add multiple items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
err := backend.Clear(ctx)
require.NoError(t, err)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(0), size)
})
}
// TestMemoryBackend_TTLExpiration tests TTL and expiration
func TestMemoryBackend_TTLExpiration(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.CleanupInterval = 50 * time.Millisecond
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "short-ttl-key"
value := []byte("short-ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait for expiration
time.Sleep(150 * time.Millisecond)
// Should be expired
_, _, exists, err = backend.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLDecrement", func(t *testing.T) {
key := "ttl-decrement-key"
value := []byte("ttl-decrement-value")
ttl := 2 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Check TTL immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait a bit
time.Sleep(500 * time.Millisecond)
// Check TTL again - should be less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
})
t.Run("CleanupExpiredItems", func(t *testing.T) {
// Set multiple items with short TTL
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
err := backend.Set(ctx, key, value, 50*time.Millisecond)
require.NoError(t, err)
}
// Wait for cleanup to run
time.Sleep(200 * time.Millisecond)
// All items should be cleaned up
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Expired items should be cleaned up")
}
})
}
// TestMemoryBackend_LRUEviction tests LRU eviction
func TestMemoryBackend_LRUEviction(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 5
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Fill cache to max size
for i := 0; i < 5; i++ {
key := fmt.Sprintf("lru-key-%d", i)
value := []byte(fmt.Sprintf("lru-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Access first item to make it most recently used
_, _, exists, err := backend.Get(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists)
// Add a new item - should evict lru-key-1 (least recently used)
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
require.NoError(t, err)
// lru-key-0 should still exist (was accessed recently)
exists, err = backend.Exists(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item should not be evicted")
// lru-key-1 should be evicted
exists, err = backend.Exists(ctx, "lru-key-1")
require.NoError(t, err)
assert.False(t, exists, "Least recently used item should be evicted")
// Check eviction count
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestMemoryBackend_MemoryLimit tests memory-based eviction
func TestMemoryBackend_MemoryLimit(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 100
config.MaxMemoryBytes = 1024 // 1KB limit
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items until memory limit is reached
largeValue := make([]byte, 512) // 512 bytes each
for i := 0; i < 5; i++ {
key := fmt.Sprintf("mem-key-%d", i)
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
}
stats := backend.GetStats()
memory := stats["memory"].(int64)
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
}
// TestMemoryBackend_ConcurrentAccess tests thread safety
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
// Random deletes
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
// Verify stats are consistent
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
}
// TestMemoryBackend_UpdateExisting tests updating existing keys
func TestMemoryBackend_UpdateExisting(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
// Size should not increase (same key)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
}
// TestMemoryBackend_Stats tests statistics tracking
func TestMemoryBackend_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add items and track hits/misses
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
// Hit
backend.Get(ctx, "key1")
// Miss
backend.Get(ctx, "non-existent")
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestMemoryBackend_EmptyValues tests handling of empty values
func TestMemoryBackend_EmptyValues(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestMemoryBackend_LargeValues tests handling of large values
func TestMemoryBackend_LargeValues(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestMemoryBackend_Close tests proper cleanup on close
func TestMemoryBackend_Close(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
ctx := context.Background()
// Add some items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations after close should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
_, _, _, err = backend.Get(ctx, "close-key-0")
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Closing again should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestMemoryBackend_Ping tests ping operation
func TestMemoryBackend_Ping(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
func TestMemoryBackend_ValueIsolation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "isolation-key"
originalValue := []byte("original-value")
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
require.NoError(t, err)
// Get value and modify it
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Modify retrieved value
if len(retrieved) > 0 {
retrieved[0] = 'X'
}
// Get again - should be unchanged
retrieved2, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
}
// TestMemoryBackend_Keys tests the Keys method with pattern matching
func TestMemoryBackend_Keys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add test data
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
for _, key := range testKeys {
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
t.Run("AllKeys", func(t *testing.T) {
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.Len(t, keys, 5)
})
t.Run("SpecificPattern", func(t *testing.T) {
// Simple exact match
keys, err := backend.Keys(ctx, "user:1")
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Contains(t, keys, "user:1")
})
t.Run("ExcludesExpired", func(t *testing.T) {
// Add an expired key
expiredKey := "expired:key"
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.Keys(ctx, "*")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Size tests the Size method
func TestMemoryBackend_Size(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initially empty
size, err := backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), size)
// Add items
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key-%d", i)
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(5), size)
// Delete one
backend.Delete(ctx, "key-0")
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(4), size)
// After close
backend.Close()
_, err = backend.Size(ctx)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
}
// TestMemoryBackend_TTL tests the TTL method
func TestMemoryBackend_TTL(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ExistingKey", func(t *testing.T) {
key := "ttl-key"
ttl := 1 * time.Minute
err := backend.Set(ctx, key, []byte("value"), ttl)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Greater(t, remaining, 50*time.Second)
assert.LessOrEqual(t, remaining, ttl)
})
t.Run("NonExistentKey", func(t *testing.T) {
_, err := backend.TTL(ctx, "non-existent")
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("NoExpiration", func(t *testing.T) {
key := "no-expiry"
// TTL of 0 typically means no expiration
err := backend.Set(ctx, key, []byte("value"), 0)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
// No expiration returns 0
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.TTL(ctx, "key")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Expire tests the Expire method
func TestMemoryBackend_Expire(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("UpdateTTL", func(t *testing.T) {
key := "expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Update to shorter TTL
err = backend.Expire(ctx, key, 5*time.Second)
require.NoError(t, err)
// Check new TTL
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.LessOrEqual(t, remaining, 5*time.Second)
})
t.Run("NonExistentKey", func(t *testing.T) {
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("RemoveExpiration", func(t *testing.T) {
key := "no-expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Set TTL to 0 to remove expiration
err = backend.Expire(ctx, key, 0)
require.NoError(t, err)
// TTL should now be 0
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_IsHealthy tests the IsHealthy method
func TestMemoryBackend_IsHealthy(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
// Should be healthy when open
assert.True(t, backend.IsHealthy())
// Should be unhealthy after close
backend.Close()
assert.False(t, backend.IsHealthy())
}
// TestMemoryBackend_Type tests the Type method
func TestMemoryBackend_Type(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
backendType := backend.Type()
assert.Equal(t, TypeMemory, backendType)
}
// TestMemoryBackend_Capabilities tests the Capabilities method
func TestMemoryBackend_Capabilities(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
caps := backend.Capabilities()
require.NotNil(t, caps)
// Memory backend should not be distributed or persistent
assert.False(t, caps.Distributed)
assert.False(t, caps.Persistent)
// Should support eviction and TTL
assert.True(t, caps.Eviction)
assert.True(t, caps.TTL)
assert.True(t, caps.SupportsExpire)
assert.True(t, caps.SupportsMultiGet)
// Check limits
assert.Greater(t, caps.MaxKeySize, int64(0))
assert.Greater(t, caps.MaxValueSize, int64(0))
}
// TestMatchPattern tests the matchPattern helper function
func TestMatchPattern(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
key string
matches bool
}{
{"*", "any-key", true},
{"*", "another", true},
{"user:1", "user:1", true},
{"user:1", "user:2", false},
{"*:suffix", "prefix:suffix", true},
{"*suffix", "prefix-suffix", true},
{"*abc", "xyzabc", true},
{"*abc", "xyz", false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
result := matchPattern(tt.pattern, tt.key)
assert.Equal(t, tt.matches, result)
})
}
}
+143
View File
@@ -0,0 +1,143 @@
package backends
import (
"context"
"time"
)
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
type MemoryBackend struct {
*MemoryCacheBackend
}
// NewMemoryBackend creates a new memory backend from a config
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
maxSize := int64(config.MaxSize)
if maxSize <= 0 {
maxSize = 1000
}
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
return &MemoryBackend{
MemoryCacheBackend: cacheBackend,
}, nil
}
// Set stores a value in the cache with the specified TTL
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
if err == ErrBackendUnavailable {
return ErrBackendClosed
}
return err
}
// Get retrieves a value from the cache
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
val, err := m.MemoryCacheBackend.Get(ctx, key)
if err != nil {
if err == ErrCacheMiss {
return nil, 0, false, nil
}
if err == ErrBackendUnavailable {
return nil, 0, false, ErrBackendClosed
}
return nil, 0, false, err
}
// Get TTL using the TTL method
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
if ttlErr != nil {
// If we can't get TTL, still return the value with 0 TTL
ttl = 0
}
// Convert interface{} to []byte
var valueBytes []byte
if val != nil {
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, return an error
return nil, 0, false, ErrInvalidValue
}
}
return valueBytes, ttl, true, nil
}
// Delete removes a key from the cache
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
// Check if key exists first
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
if err != nil {
return false, err
}
if !exists {
return false, nil
}
err = m.MemoryCacheBackend.Delete(ctx, key)
if err != nil {
return false, err
}
return true, nil
}
// Exists checks if a key exists in the cache
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
return m.MemoryCacheBackend.Exists(ctx, key)
}
// Clear removes all keys from the cache
func (m *MemoryBackend) Clear(ctx context.Context) error {
return m.MemoryCacheBackend.Clear(ctx)
}
// GetStats returns cache statistics
func (m *MemoryBackend) GetStats() map[string]interface{} {
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
// Convert BackendStats to map
hitRate := float64(0)
total := stats.Hits + stats.Misses
if total > 0 {
hitRate = float64(stats.Hits) / float64(total)
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
"shard_count": m.MemoryCacheBackend.GetShardCount(),
}
}
// Close shuts down the cache backend and releases resources
func (m *MemoryBackend) Close() error {
return m.MemoryCacheBackend.Close()
}
// Ping checks if the backend is healthy and responsive
func (m *MemoryBackend) Ping(ctx context.Context) error {
return m.MemoryCacheBackend.Ping(ctx)
}
// Ensure MemoryBackend implements CacheBackend
var _ CacheBackend = (*MemoryBackend)(nil)
+566
View File
@@ -0,0 +1,566 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
)
// Pure-Go Redis client implementation
// Compatible with Yaegi interpreter (no unsafe package)
// Implements RESP protocol for basic Redis operations
var (
ErrPoolExhausted = errors.New("connection pool exhausted")
)
// RedisBackend implements a Redis-based cache backend using pure Go
type RedisBackend struct {
config *Config
pool *ConnectionPool
healthMonitor *HealthMonitor
// Metrics
hits atomic.Int64
misses atomic.Int64
// Lifecycle
closed atomic.Bool
mu sync.Mutex
}
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
func NewRedisBackend(config *Config) (*RedisBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.RedisAddr == "" {
return nil, fmt.Errorf("redis address is required")
}
// Create connection pool with health checks enabled
// Timeouts are kept short to prevent request pileup when Redis is slow/stalled.
// The UniversalCache uses 200ms context timeout, so socket timeouts should be
// shorter to allow proper context cancellation handling.
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 500 * time.Millisecond,
WriteTimeout: 500 * time.Millisecond,
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(poolConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Create health monitor
healthConfig := DefaultHealthMonitorConfig()
healthMonitor := NewHealthMonitor(pool, healthConfig)
backend := &RedisBackend{
config: config,
pool: pool,
healthMonitor: healthMonitor,
}
// Test connectivity
if err := backend.Ping(context.Background()); err != nil {
_ = pool.Close()
return nil, fmt.Errorf("failed to ping Redis: %w", err)
}
// Start health monitoring
healthMonitor.Start()
return backend, nil
}
// Set stores a value in Redis with TTL
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
// Execute with retry logic
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
var err error
// Use PSETEX for millisecond precision, SETEX for second precision
if ttl > 0 {
ttlMillis := ttl.Milliseconds()
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs (millisecond precision)
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs (second precision)
ttlSeconds := int(ttl.Seconds())
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
_, err = conn.Do("SET", prefixedKey, string(value))
}
return err
})
}
// Get retrieves a value from Redis
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if r.closed.Load() {
return nil, 0, false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
var resultValue []byte
var resultTTL time.Duration
var resultExists bool
// Execute with retry logic
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
// Get value
resp, err := conn.Do("GET", prefixedKey)
if err != nil {
if errors.Is(err, ErrNilResponse) {
r.misses.Add(1)
resultExists = false
return nil // Not an error, key just doesn't exist
}
return err
}
value, err := RESPString(resp)
if err != nil {
return err
}
// Get TTL
ttlResp, err := conn.Do("TTL", prefixedKey)
if err != nil {
// If TTL fails, still return the value
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = 0
resultExists = true
return nil
}
ttlSeconds, _ := RESPInt(ttlResp)
var ttl time.Duration
if ttlSeconds > 0 {
ttl = time.Duration(ttlSeconds) * time.Second
}
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = ttl
resultExists = true
return nil
})
return resultValue, resultTTL, resultExists, err
}
// Delete removes a key from Redis
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("DEL", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Exists checks if a key exists in Redis
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("EXISTS", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Clear removes all keys with the configured prefix
func (r *RedisBackend) Clear(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
// Use FLUSHDB if no prefix (clear entire DB)
if r.config.RedisPrefix == "" {
_, err := conn.Do("FLUSHDB")
return err
}
// With prefix, we need to scan and delete keys
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
pattern := r.config.RedisPrefix + "*"
resp, err := conn.Do("KEYS", pattern)
if err != nil {
return err
}
// Extract keys from array response
keys, ok := resp.([]interface{})
if !ok || len(keys) == 0 {
return nil
}
// Delete each key
for _, keyInterface := range keys {
key, err := RESPString(keyInterface)
if err != nil {
continue
}
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
}
return nil
}
// GetStats returns backend statistics
func (r *RedisBackend) GetStats() map[string]interface{} {
hits := r.hits.Load()
misses := r.misses.Load()
total := hits + misses
hitRate := float64(0)
if total > 0 {
hitRate = float64(hits) / float64(total)
}
stats := map[string]interface{}{
"backend": "redis-pure-go",
"address": r.config.RedisAddr,
"hits": hits,
"misses": misses,
"hit_rate": hitRate,
"pool": r.pool.Stats(),
}
// Add health monitor stats if available
if r.healthMonitor != nil {
stats["health"] = r.healthMonitor.GetStats()
}
return stats
}
// Ping checks Redis connectivity
func (r *RedisBackend) Ping(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
_, err = conn.Do("PING")
return err
}
// Close closes the Redis backend and all connections
func (r *RedisBackend) Close() error {
if r.closed.Swap(true) {
return nil // Already closed
}
r.mu.Lock()
defer r.mu.Unlock()
// Stop health monitor
if r.healthMonitor != nil {
r.healthMonitor.Stop()
}
// Close connection pool
if r.pool != nil {
return r.pool.Close()
}
return nil
}
// prefixKey adds the configured prefix to a key
func (r *RedisBackend) prefixKey(key string) string {
if r.config.RedisPrefix == "" {
return key
}
return r.config.RedisPrefix + key
}
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
for attempt := 0; attempt < maxRetries; attempt++ {
// Check context before each attempt to fail fast
if ctx.Err() != nil {
return ctx.Err()
}
conn, err := r.pool.Get(ctx)
if err != nil {
// If we can't get a connection and this is the last attempt, fail
if attempt == maxRetries-1 {
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
// Execute the operation
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
// If successful, return
if err == nil {
return nil
}
// If error is not retryable or last attempt, fail
if attempt == maxRetries-1 || !isRetryableError(err) {
return err
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
return fmt.Errorf("operation failed after %d attempts", maxRetries)
}
// isRetryableError determines if an error is worth retrying
func isRetryableError(err error) bool {
if err == nil {
return false
}
// Retry on connection errors, timeouts, etc.
// Don't retry on application-level errors like wrong type
errMsg := err.Error()
retryablePatterns := []string{
"connection",
"timeout",
"EOF",
"broken pipe",
"reset by peer",
}
for _, pattern := range retryablePatterns {
if contains(errMsg, pattern) {
return true
}
}
return false
}
// SetMany stores multiple values in Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
if len(items) == 0 {
return nil
}
// For single items, use regular Set
if len(items) == 1 {
for key, value := range items {
return r.Set(ctx, key, value, ttl)
}
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all SET commands
ttlSeconds := int(ttl.Seconds())
ttlMillis := ttl.Milliseconds()
for key, value := range items {
prefixedKey := r.prefixKey(key)
if ttl > 0 {
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
pipeline.Queue("SET", prefixedKey, string(value))
}
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return fmt.Errorf("pipeline SetMany failed: %w", err)
}
// Check responses for errors (each should be "OK")
for i, resp := range responses {
if resp == nil {
continue
}
if str, ok := resp.(string); ok && str == "OK" {
continue
}
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
}
return nil
}
// GetMany retrieves multiple values from Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
if len(keys) == 0 {
return make(map[string][]byte), nil
}
// For single key, use regular Get
if len(keys) == 1 {
result := make(map[string][]byte)
value, _, exists, err := r.Get(ctx, keys[0])
if err != nil {
return nil, err
}
if exists {
result[keys[0]] = value
}
return result, nil
}
conn, err := r.pool.Get(ctx)
if err != nil {
return nil, err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all GET commands
prefixedKeys := make([]string, len(keys))
for i, key := range keys {
prefixedKeys[i] = r.prefixKey(key)
pipeline.Queue("GET", prefixedKeys[i])
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
}
// Process responses
result := make(map[string][]byte)
for i, resp := range responses {
if resp == nil {
// Key doesn't exist
r.misses.Add(1)
continue
}
value, err := RESPString(resp)
if err != nil {
// Invalid response, skip this key
r.misses.Add(1)
continue
}
r.hits.Add(1)
result[keys[i]] = []byte(value)
}
return result, nil
}
+170
View File
@@ -0,0 +1,170 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
type HealthMonitor struct {
pool *ConnectionPool
config *HealthMonitorConfig
stopChan chan struct{}
wg sync.WaitGroup
lastCheckTime atomic.Int64
consecutiveFailures atomic.Int64
totalChecks atomic.Int64
totalFailures atomic.Int64
healthy atomic.Bool
running atomic.Bool
}
// HealthMonitorConfig configures the health monitor
type HealthMonitorConfig struct {
OnHealthChange func(healthy bool)
CheckInterval time.Duration
Timeout time.Duration
UnhealthyThreshold int
}
// DefaultHealthMonitorConfig returns default health monitor configuration
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
return &HealthMonitorConfig{
CheckInterval: 5 * time.Second,
Timeout: 3 * time.Second,
UnhealthyThreshold: 3,
}
}
// NewHealthMonitor creates a new health monitor
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
if config == nil {
config = DefaultHealthMonitorConfig()
}
hm := &HealthMonitor{
pool: pool,
config: config,
stopChan: make(chan struct{}),
}
hm.healthy.Store(true) // Assume healthy initially
return hm
}
// Start begins health monitoring
func (hm *HealthMonitor) Start() {
if hm.running.Swap(true) {
return // Already running
}
hm.wg.Add(1)
go hm.monitorLoop()
}
// Stop stops health monitoring
func (hm *HealthMonitor) Stop() {
if !hm.running.Swap(false) {
return // Not running
}
close(hm.stopChan)
hm.wg.Wait()
}
// IsHealthy returns the current health status
func (hm *HealthMonitor) IsHealthy() bool {
return hm.healthy.Load()
}
// GetStats returns health monitor statistics
func (hm *HealthMonitor) GetStats() map[string]interface{} {
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
return map[string]interface{}{
"healthy": hm.healthy.Load(),
"consecutive_failures": hm.consecutiveFailures.Load(),
"total_checks": hm.totalChecks.Load(),
"total_failures": hm.totalFailures.Load(),
"last_check": lastCheck,
}
}
// monitorLoop runs the health check loop
func (hm *HealthMonitor) monitorLoop() {
defer hm.wg.Done()
ticker := time.NewTicker(hm.config.CheckInterval)
defer ticker.Stop()
// Perform initial check immediately
hm.performHealthCheck()
for {
select {
case <-hm.stopChan:
return
case <-ticker.C:
hm.performHealthCheck()
}
}
}
// performHealthCheck executes a health check
func (hm *HealthMonitor) performHealthCheck() {
hm.totalChecks.Add(1)
hm.lastCheckTime.Store(time.Now().Unix())
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
defer cancel()
// Try to get a connection and ping Redis
conn, err := hm.pool.Get(ctx)
if err != nil {
hm.recordFailure()
return
}
defer hm.pool.Put(conn)
// Ping Redis
_, err = conn.Do("PING")
if err != nil {
hm.recordFailure()
return
}
// Success!
hm.recordSuccess()
}
// recordSuccess records a successful health check
func (hm *HealthMonitor) recordSuccess() {
wasHealthy := hm.healthy.Load()
hm.consecutiveFailures.Store(0)
hm.healthy.Store(true)
// Trigger callback if health changed
if !wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(true)
}
}
// recordFailure records a failed health check
func (hm *HealthMonitor) recordFailure() {
hm.totalFailures.Add(1)
failures := hm.consecutiveFailures.Add(1)
wasHealthy := hm.healthy.Load()
// Mark unhealthy if threshold exceeded
if failures >= int64(hm.config.UnhealthyThreshold) {
hm.healthy.Store(false)
// Trigger callback if health changed
if wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(false)
}
}
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHealthMonitor_BasicOperation tests basic health monitoring
func TestHealthMonitor_BasicOperation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create health monitor with fast check interval for testing
hmConfig := &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
}
hm := NewHealthMonitor(pool, hmConfig)
require.NotNil(t, hm)
// Initially should be healthy
assert.True(t, hm.IsHealthy())
// Start monitoring
hm.Start()
defer hm.Stop()
// Wait for a few checks
time.Sleep(500 * time.Millisecond)
// Should still be healthy
assert.True(t, hm.IsHealthy())
// Check stats
stats := hm.GetStats()
require.NotNil(t, stats)
assert.True(t, stats["healthy"].(bool))
assert.Greater(t, stats["total_checks"].(int64), int64(0))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var healthChangedCalled atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if !healthy {
healthChangedCalled.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
// Check stats
stats := hm.GetStats()
assert.False(t, stats["healthy"].(bool))
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
assert.Greater(t, stats["total_failures"].(int64), int64(0))
}
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var recoveryDetected atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if healthy {
recoveryDetected.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Should detect server failure")
// Clear error to simulate recovery
mr.ClearError()
// Wait for recovery
time.Sleep(350 * time.Millisecond)
// Should be healthy again
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
// Consecutive failures should be reset
stats := hm.GetStats()
assert.True(t, stats["healthy"].(bool))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_StartStop tests start/stop behavior
func TestHealthMonitor_StartStop(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
// Start monitoring
hm.Start()
assert.True(t, hm.running.Load())
// Starting again should be no-op
hm.Start()
assert.True(t, hm.running.Load())
// Stop monitoring
hm.Stop()
assert.False(t, hm.running.Load())
// Stopping again should be no-op
hm.Stop()
assert.False(t, hm.running.Load())
}
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create multiple monitors
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 150 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
})
// Start both
hm1.Start()
hm2.Start()
// Both should be healthy
time.Sleep(200 * time.Millisecond)
assert.True(t, hm1.IsHealthy())
assert.True(t, hm2.IsHealthy())
// Stop both
hm1.Stop()
hm2.Stop()
// Verify they stopped
assert.False(t, hm1.running.Load())
assert.False(t, hm2.running.Load())
}
// TestHealthMonitor_StatsAccuracy tests stats tracking
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Wait for some checks
time.Sleep(550 * time.Millisecond)
stats := hm.GetStats()
// Should have performed multiple checks
totalChecks := stats["total_checks"].(int64)
assert.GreaterOrEqual(t, totalChecks, int64(4))
// All checks should succeed
assert.Equal(t, int64(0), stats["total_failures"].(int64))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
// Last check time should be recent (within check interval + buffer)
// Use 2s tolerance to account for CI runner load and timing variance
lastCheck := stats["last_check"].(time.Time)
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
}
// TestHealthMonitor_DefaultConfig tests default configuration
func TestHealthMonitor_DefaultConfig(t *testing.T) {
config := DefaultHealthMonitorConfig()
assert.Equal(t, 5*time.Second, config.CheckInterval)
assert.Equal(t, 3*time.Second, config.Timeout)
assert.Equal(t, 3, config.UnhealthyThreshold)
assert.Nil(t, config.OnHealthChange)
}
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1, // Very small pool
ConnectTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond, // Short timeout
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Get the only connection, blocking health checks
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
// Wait for health check attempts
time.Sleep(350 * time.Millisecond)
// Health monitor might mark as unhealthy due to timeouts
stats := hm.GetStats()
t.Logf("Stats with blocked pool: %+v", stats)
// Return connection
pool.Put(conn)
// Wait for recovery
time.Sleep(300 * time.Millisecond)
// Should recover
assert.True(t, hm.IsHealthy())
}
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
func TestConnectionPool_WithHealthChecks(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn))
// Use connection
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse and validate
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
}
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 3,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get and return a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
initialTotal := pool.totalConns.Load()
// Close the connection manually to make it stale
conn.Close()
// Get another connection - should detect stale and create new
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn2))
pool.Put(conn2)
// Total connections might be same or less (stale removed)
finalTotal := pool.totalConns.Load()
assert.LessOrEqual(t, finalTotal, initialTotal+1)
}
+461
View File
@@ -0,0 +1,461 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTestRedis creates a miniredis instance for testing
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err)
t.Cleanup(func() {
mr.Close()
})
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "test:",
PoolSize: 5,
})
require.NoError(t, err)
t.Cleanup(func() {
backend.Close()
})
return mr, backend
}
// TestPipeline_Basic tests basic pipeline functionality
func TestPipeline_Basic(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
config := &PoolConfig{
Address: mr.Addr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SingleCommand", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "single-key", "single-value")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Equal(t, "OK", responses[0])
})
t.Run("MultipleCommands", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "key1", "value1")
pipeline.Queue("SET", "key2", "value2")
pipeline.Queue("SET", "key3", "value3")
pipeline.Queue("GET", "key1")
pipeline.Queue("GET", "key2")
pipeline.Queue("GET", "key3")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 6)
// First 3 are SET responses
assert.Equal(t, "OK", responses[0])
assert.Equal(t, "OK", responses[1])
assert.Equal(t, "OK", responses[2])
// Last 3 are GET responses
assert.Equal(t, "value1", responses[3])
assert.Equal(t, "value2", responses[4])
assert.Equal(t, "value3", responses[5])
})
t.Run("EmptyPipeline", func(t *testing.T) {
pipeline := conn.NewPipeline()
responses, err := pipeline.Execute()
require.NoError(t, err)
assert.Nil(t, responses)
})
t.Run("NilResponses", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("GET", "nonexistent-key")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Nil(t, responses[0])
})
}
// TestPipeline_SetMany tests pipelined SetMany
func TestPipeline_SetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetManyItems", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key %s should exist", key)
assert.Equal(t, expectedValue, value)
}
})
t.Run("SetManyEmpty", func(t *testing.T) {
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
require.NoError(t, err)
})
t.Run("SetManySingleItem", func(t *testing.T) {
items := map[string][]byte{
"single-setmany": []byte("single-value"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
value, _, exists, err := backend.Get(ctx, "single-setmany")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("single-value"), value)
})
t.Run("SetManyNoTTL", func(t *testing.T) {
items := map[string][]byte{
"nottl-key1": []byte("value1"),
"nottl-key2": []byte("value2"),
}
err := backend.SetMany(ctx, items, 0)
require.NoError(t, err)
// Keys should exist
for key := range items {
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
})
}
// TestPipeline_GetMany tests pipelined GetMany
func TestPipeline_GetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10; i++ {
key := fmt.Sprintf("getmany-key-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
t.Run("GetManyExisting", func(t *testing.T) {
keys := make([]string, 10)
for i := 0; i < 10; i++ {
keys[i] = fmt.Sprintf("getmany-key-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 10)
for i, key := range keys {
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
}
})
t.Run("GetManyMixed", func(t *testing.T) {
keys := []string{
"getmany-key-0", // exists
"nonexistent-key-1", // doesn't exist
"getmany-key-2", // exists
"nonexistent-key-2", // doesn't exist
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
assert.NotContains(t, results, "nonexistent-key-1")
assert.NotContains(t, results, "nonexistent-key-2")
})
t.Run("GetManyEmpty", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{})
require.NoError(t, err)
assert.NotNil(t, results)
assert.Len(t, results, 0)
})
t.Run("GetManySingleKey", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
})
t.Run("GetManyAllNonexistent", func(t *testing.T) {
keys := []string{
"nonexistent-1",
"nonexistent-2",
"nonexistent-3",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 0)
})
}
// TestPipeline_LargeBatch tests pipelining with large batches
func TestPipeline_LargeBatch(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetMany100Items", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify random samples
for _, i := range []int{0, 25, 50, 75, 99} {
key := fmt.Sprintf("large-batch-%d", i)
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
}
})
t.Run("GetMany100Items", func(t *testing.T) {
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("large-batch-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 100)
})
}
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
func TestPipeline_Stats(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Set some items
items := map[string][]byte{
"stats-key-1": []byte("value1"),
"stats-key-2": []byte("value2"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Get items (some exist, some don't)
keys := []string{
"stats-key-1",
"stats-key-2",
"stats-key-nonexistent",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2)
// Check stats
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Equal(t, int64(2), hits, "Should have 2 hits")
assert.Equal(t, int64(1), misses, "Should have 1 miss")
}
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
func BenchmarkPipeline_SetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
}
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
func BenchmarkPipeline_GetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 100; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
// Prepare keys
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("bench-key-%d", i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
}
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
func BenchmarkPipeline_VsSequential(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
keys := make([]string, 50)
for i := 0; i < 50; i++ {
key := fmt.Sprintf("compare-key-%d", i)
keys[i] = key
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
}
b.Run("Pipelined-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
})
b.Run("Sequential-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, value := range items {
_ = backend.Set(ctx, key, value, time.Minute)
}
}
})
// Pre-populate for get benchmarks
_ = backend.SetMany(ctx, items, time.Hour)
b.Run("Pipelined-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
})
b.Run("Sequential-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, key := range keys {
_, _, _, _ = backend.Get(ctx, key)
}
}
})
}
+455
View File
@@ -0,0 +1,455 @@
package backends
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// ConnectionPool manages a pool of Redis connections
// Pure-Go implementation compatible with Yaegi
type ConnectionPool struct {
config *PoolConfig
connections chan *RedisConn
mu sync.Mutex
closed atomic.Bool
// Metrics
activeConns atomic.Int32
totalConns atomic.Int32
gets atomic.Int64
puts atomic.Int64
timeouts atomic.Int64
}
// PoolConfig holds connection pool configuration
type PoolConfig struct {
Address string
Password string
DB int
MaxConnections int
ConnectTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
}
// NewConnectionPool creates a new connection pool
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
if config == nil {
return nil, errors.New("config is required")
}
if config.MaxConnections <= 0 {
config.MaxConnections = 10
}
if config.ConnectTimeout == 0 {
config.ConnectTimeout = 5 * time.Second
}
pool := &ConnectionPool{
config: config,
connections: make(chan *RedisConn, config.MaxConnections),
}
return pool, nil
}
// Get retrieves a connection from the pool or creates a new one
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
if p.closed.Load() {
return nil, ErrBackendClosed
}
p.gets.Add(1)
// Try to get a connection with validation
maxAttempts := 3
for attempt := 0; attempt < maxAttempts; attempt++ {
var conn *RedisConn
var err error
select {
case conn = <-p.connections:
// Reuse existing connection - validate if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
// Connection is stale, close it and try again
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
return nil, err
}
// Wait before retry with exponential backoff
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
continue
}
p.activeConns.Add(1)
p.totalConns.Add(1)
return conn, nil
}
// Pool exhausted, wait for a connection with timeout
select {
case conn = <-p.connections:
// Validate connection if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
p.timeouts.Add(1)
return nil, ctx.Err()
case <-time.After(p.config.ConnectTimeout):
p.timeouts.Add(1)
return nil, ErrPoolExhausted
}
}
}
return nil, errors.New("failed to get healthy connection after retries")
}
// Put returns a connection to the pool
func (p *ConnectionPool) Put(conn *RedisConn) {
if conn == nil {
return
}
p.puts.Add(1)
p.activeConns.Add(-1)
if p.closed.Load() || conn.closed.Load() {
_ = conn.Close()
p.totalConns.Add(-1)
return
}
// Return to pool (non-blocking)
select {
case p.connections <- conn:
// Successfully returned to pool
default:
// Pool full, close connection
_ = conn.Close()
p.totalConns.Add(-1)
}
}
// Close closes all connections in the pool
func (p *ConnectionPool) Close() error {
if p.closed.Swap(true) {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
close(p.connections)
// Close all pooled connections
for conn := range p.connections {
_ = conn.Close()
}
return nil
}
// Stats returns pool statistics
func (p *ConnectionPool) Stats() map[string]interface{} {
return map[string]interface{}{
"active_connections": p.activeConns.Load(),
"total_connections": p.totalConns.Load(),
"max_connections": p.config.MaxConnections,
"gets": p.gets.Load(),
"puts": p.puts.Load(),
"timeouts": p.timeouts.Load(),
}
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
redisConn := &RedisConn{
conn: conn,
readTimeout: p.config.ReadTimeout,
writeTimeout: p.config.WriteTimeout,
}
// Authenticate if password is provided
if p.config.Password != "" {
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
_ = redisConn.Close()
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
// Select database
if p.config.DB != 0 {
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
_ = redisConn.Close()
return nil, fmt.Errorf("failed to select database: %w", err)
}
}
return redisConn, nil
}
// RedisConn represents a single Redis connection
type RedisConn struct {
conn net.Conn
readTimeout time.Duration
writeTimeout time.Duration
closed atomic.Bool
mu sync.Mutex
}
// Do executes a Redis command and returns the response
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
if c.closed.Load() {
return nil, ErrBackendClosed
}
c.mu.Lock()
defer c.mu.Unlock()
// Validate argument count to prevent integer overflow in slice operations
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
const maxSafeArgs = (1 << 20) - 1
if len(args) > maxSafeArgs {
return nil, errors.New("too many arguments: exceeds maximum safe count")
}
// Build command arguments
// Validate total argument size to prevent memory exhaustion
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
totalBytes := len(command)
for _, s := range args {
// Protect against possible overflow
if len(s) > maxTotalArgBytes-totalBytes {
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
}
totalBytes += len(s)
if totalBytes > maxTotalArgBytes {
return nil, errors.New("total argument size exceeds maximum allowed")
}
}
// Build command slice: prepend command to args
// Using append avoids arithmetic on potentially large len(args)
cmdArgs := append([]string{command}, args...)
// Set write timeout
if c.writeTimeout > 0 {
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// Write command (using pooled writer for memory efficiency)
writer := NewRESPWriter(c.conn)
err := writer.WriteCommand(cmdArgs...)
writer.Release() // Return to pool immediately after use
if err != nil {
c.closed.Store(true)
return nil, err
}
// Set read timeout
if c.readTimeout > 0 {
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
// Read response (using pooled reader for memory efficiency)
reader := NewRESPReader(c.conn)
resp, err := reader.ReadResponse()
reader.Release() // Return to pool immediately after use
if err != nil {
if !errors.Is(err, ErrNilResponse) {
c.closed.Store(true)
}
return nil, err
}
return resp, nil
}
// Close closes the connection
func (c *RedisConn) Close() error {
if c.closed.Swap(true) {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
return c.conn.Close()
}
return nil
}
// isConnectionHealthy validates a connection is still working
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
if conn == nil || conn.closed.Load() {
return false
}
// Set a read deadline for the ping
if conn.conn != nil {
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
}
_, err := conn.Do("PING")
return err == nil
}
// Pipeline represents a Redis pipeline for batch operations
// It queues multiple commands and executes them in a single round-trip
type Pipeline struct {
conn *RedisConn
commands []pipelineCommand
mu sync.Mutex
}
// pipelineCommand represents a single command in the pipeline
type pipelineCommand struct {
command string
args []string
}
// NewPipeline creates a new pipeline for the connection
func (c *RedisConn) NewPipeline() *Pipeline {
return &Pipeline{
conn: c,
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
}
}
// Queue adds a command to the pipeline
func (p *Pipeline) Queue(command string, args ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = append(p.commands, pipelineCommand{
command: command,
args: args,
})
}
// Execute sends all queued commands and returns all responses
// Returns a slice of responses in the same order as commands were queued
func (p *Pipeline) Execute() ([]interface{}, error) {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.commands) == 0 {
return nil, nil
}
if p.conn.closed.Load() {
return nil, ErrBackendClosed
}
p.conn.mu.Lock()
defer p.conn.mu.Unlock()
// Set write timeout for all commands
if p.conn.writeTimeout > 0 {
// Use longer timeout for batch operations
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second // Cap at 30 seconds
}
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
}
// Write all commands (pipelining - send all before reading any responses)
writer := NewRESPWriter(p.conn.conn)
for _, cmd := range p.commands {
cmdArgs := append([]string{cmd.command}, cmd.args...)
if err := writer.WriteCommand(cmdArgs...); err != nil {
writer.Release()
p.conn.closed.Store(true)
return nil, fmt.Errorf("pipeline write error: %w", err)
}
}
writer.Release()
// Set read timeout for all responses
if p.conn.readTimeout > 0 {
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second
}
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
}
// Read all responses
responses := make([]interface{}, len(p.commands))
reader := NewRESPReader(p.conn.conn)
defer reader.Release()
for i := range p.commands {
resp, err := reader.ReadResponse()
if err != nil {
// For nil responses, store nil instead of erroring
if errors.Is(err, ErrNilResponse) {
responses[i] = nil
continue
}
p.conn.closed.Store(true)
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
}
responses[i] = resp
}
return responses, nil
}
// Clear resets the pipeline for reuse
func (p *Pipeline) Clear() {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = p.commands[:0]
}
// Len returns the number of queued commands
func (p *Pipeline) Len() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.commands)
}
+620
View File
@@ -0,0 +1,620 @@
package backends
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnectionPool_BasicOperations tests basic pool operations
func TestConnectionPool_BasicOperations(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
t.Run("GetAndPutConnection", func(t *testing.T) {
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Verify connection works
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse same connection
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
})
t.Run("Stats", func(t *testing.T) {
stats := pool.Stats()
require.NotNil(t, stats)
assert.Contains(t, stats, "active_connections")
assert.Contains(t, stats, "total_connections")
assert.Contains(t, stats, "max_connections")
assert.Equal(t, 5, stats["max_connections"])
})
}
// TestConnectionPool_MaxConnections tests pool size limits
func TestConnectionPool_MaxConnections(t *testing.T) {
mr := NewMiniredisServer(t)
maxConns := 3
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: maxConns,
ConnectTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get max connections
conns := make([]*RedisConn, maxConns)
for i := 0; i < maxConns; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
conns[i] = conn
}
// Verify stats
stats := pool.Stats()
assert.Equal(t, int32(maxConns), stats["total_connections"])
assert.Equal(t, int32(maxConns), stats["active_connections"])
// Try to get one more - should block/timeout
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := pool.Get(ctx2)
require.Error(t, err)
require.Nil(t, conn)
// Return one connection
pool.Put(conns[0])
// Now we should be able to get a connection
conn, err = pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
// Cleanup
pool.Put(conn)
for i := 1; i < maxConns; i++ {
pool.Put(conns[i])
}
}
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
numGoroutines := 50
numOperations := 20
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*numOperations)
// Spawn goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
conn, err := pool.Get(ctx)
if err != nil {
errors <- err
continue
}
// Do some work
_, err = conn.Do("PING")
if err != nil {
errors <- err
}
// Return to pool
pool.Put(conn)
// Small delay
time.Sleep(time.Millisecond)
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
errorCount := 0
for err := range errors {
t.Logf("Error: %v", err)
errorCount++
}
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
// Verify stats
stats := pool.Stats()
t.Logf("Final stats: %+v", stats)
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
assert.Equal(t, int32(0), stats["active_connections"])
}
// TestConnectionPool_ContextCancellation tests context cancellation
func TestConnectionPool_ContextCancellation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get the only connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
conn2, err := pool.Get(ctx)
require.Error(t, err)
require.Nil(t, conn2)
assert.Contains(t, err.Error(), "context canceled")
// Cleanup
pool.Put(conn)
}
// TestConnectionPool_Authentication tests auth support
func TestConnectionPool_Authentication(t *testing.T) {
mr := NewMiniredisServer(t)
// Set password on miniredis
mr.server.RequireAuth("secret-password")
t.Run("CorrectPassword", func(t *testing.T) {
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "secret-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
})
t.Run("WrongPassword", func(t *testing.T) {
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "wrong-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
_, err := NewConnectionPool(config)
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
})
}
// TestConnectionPool_DatabaseSelection tests DB selection
func TestConnectionPool_DatabaseSelection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
DB: 5,
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Connection should be on DB 5
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestConnectionPool_ClosedConnection tests handling closed connections
func TestConnectionPool_ClosedConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Close it manually
conn.Close()
// Try to use it
_, err = conn.Do("PING")
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Return to pool (should be discarded)
pool.Put(conn)
// Get new connection - should create a new one
conn2, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn2)
resp, err := conn2.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn2)
}
// TestConnectionPool_Close tests pool closure
func TestConnectionPool_Close(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
// Get some connections
conns := make([]*RedisConn, 3)
for i := 0; i < 3; i++ {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
conns[i] = conn
}
// Return them
for _, conn := range conns {
pool.Put(conn)
}
// Close pool
err = pool.Close()
require.NoError(t, err)
// Try to get connection from closed pool
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Close again should be no-op
err = pool.Close()
require.NoError(t, err)
}
// TestConnectionPool_Timeouts tests various timeout scenarios
func TestConnectionPool_Timeouts(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
WriteTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Normal operation should work
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestRedisConn_DoCommand tests the Do method
func TestRedisConn_DoCommand(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SET and GET", func(t *testing.T) {
// SET
resp, err := conn.Do("SET", "testkey", "testvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// GET
resp, err = conn.Do("GET", "testkey")
require.NoError(t, err)
assert.Equal(t, "testvalue", resp)
})
t.Run("DEL", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "delkey", "delvalue")
require.NoError(t, err)
// DEL
resp, err := conn.Do("DEL", "delkey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
})
t.Run("EXISTS", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "existskey", "value")
require.NoError(t, err)
// EXISTS - key exists
resp, err := conn.Do("EXISTS", "existskey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
// EXISTS - key doesn't exist
resp, err = conn.Do("EXISTS", "nonexistent")
require.NoError(t, err)
count, err = RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
})
t.Run("TTL commands", func(t *testing.T) {
// SETEX
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// TTL
resp, err = conn.Do("TTL", "ttlkey")
require.NoError(t, err)
ttl, err := RESPInt(resp)
require.NoError(t, err)
assert.Greater(t, ttl, int64(0))
assert.LessOrEqual(t, ttl, int64(60))
})
}
// TestPoolConfig_Defaults tests default configuration values
func TestPoolConfig_Defaults(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
// Leave other fields at zero values
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Should use defaults
assert.Equal(t, 10, pool.config.MaxConnections)
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
// Verify it works
conn, err := pool.Get(context.Background())
require.NoError(t, err)
pool.Put(conn)
}
// TestConnectionPool_NilConnection tests handling nil connections
func TestConnectionPool_NilConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Putting nil should be safe
pool.Put(nil)
// Pool should still work
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
pool.Put(conn)
}
// TestConnectionPool_StatsTracking tests metrics tracking
func TestConnectionPool_StatsTracking(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Initial stats
stats := pool.Stats()
initialGets := stats["gets"].(int64)
initialPuts := stats["puts"].(int64)
// Perform operations
numOps := 10
for i := 0; i < numOps; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
}
// Check updated stats
stats = pool.Stats()
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
assert.Equal(t, int32(0), stats["active_connections"].(int32))
}
// TestRedisConn_TooManyArguments tests protection against allocation overflow
func TestRedisConn_TooManyArguments(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("AcceptableArgumentCount", func(t *testing.T) {
// Should work with reasonable number of args
args := make([]string, 100)
for i := range args {
args[i] = "value"
}
_, err := conn.Do("MSET", args...)
// May fail due to Redis constraints, but shouldn't panic or error on overflow
// Just verify it doesn't trigger our overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
t.Run("RejectExcessiveArguments", func(t *testing.T) {
// Create an absurdly large number of arguments that would cause overflow
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
args := make([]string, 1<<20) // 1,048,576 args
for i := range args {
args[i] = "x"
}
_, err := conn.Do("MSET", args...)
require.Error(t, err)
assert.Contains(t, err.Error(), "too many arguments")
})
t.Run("BoundaryCase", func(t *testing.T) {
// Test exactly at the boundary (maxSafeArgs)
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
for i := range args {
args[i] = "x"
}
_, err := conn.Do("ECHO", args...)
// Should not error due to overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
}
+545
View File
@@ -0,0 +1,545 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisBackend_BasicOperations tests basic Redis operations
func TestRedisBackend_BasicOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "redis-test-key"
value := []byte("redis-test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "redis-delete-key"
value := []byte("redis-delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
key := "redis-exists-key"
value := []byte("redis-exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
}
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
func TestRedisBackend_KeyPrefixing(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "test:prefix:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "my-key"
value := []byte("my-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check that key is stored with prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, "test:prefix:my-key", keys[0])
// Get should work without prefix
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
}
// TestRedisBackend_TTLExpiration tests TTL handling
func TestRedisBackend_TTLExpiration(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward time in miniredis
mr.FastForward(150 * time.Millisecond)
// Should be expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLRemaining", func(t *testing.T) {
key := "ttl-remaining-key"
value := []byte("ttl-remaining-value")
ttl := 10 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Get immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward 2 seconds
mr.FastForward(2 * time.Second)
// Check TTL is less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1)
})
}
// TestRedisBackend_Clear tests clearing all keys
func TestRedisBackend_Clear(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "clear-test:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add multiple keys
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Verify keys exist
keys := mr.CheckKeys()
assert.Len(t, keys, 10)
// Clear all
err = backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
keys = mr.CheckKeys()
assert.Len(t, keys, 0)
}
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
func TestRedisBackend_ConnectionFailure(t *testing.T) {
t.Parallel()
// Try to connect to non-existent Redis
config := DefaultRedisConfig("localhost:9999")
_, err := NewRedisBackend(config)
assert.Error(t, err, "Should fail to connect to non-existent Redis")
}
// TestRedisBackend_RedisErrors tests handling of Redis errors
func TestRedisBackend_RedisErrors(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Simulate Redis error
mr.SetError("simulated error")
// Operations should fail
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
assert.Error(t, err)
// Clear error
mr.ClearError()
// Operations should work again
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
assert.NoError(t, err)
}
// TestRedisBackend_ConcurrentAccess tests thread safety
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0))
}
// TestRedisBackend_Stats tests statistics tracking
func TestRedisBackend_Stats(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add and access items
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Get(ctx, "key1") // Hit
backend.Get(ctx, "non-existent") // Miss
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestRedisBackend_Ping tests health check
func TestRedisBackend_Ping(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestRedisBackend_Close tests proper cleanup
func TestRedisBackend_Close(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Double close should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestRedisBackend_UpdateExisting tests updating existing keys
func TestRedisBackend_UpdateExisting(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute)
}
// TestRedisBackend_LargeValues tests handling of large values
func TestRedisBackend_LargeValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestRedisBackend_EmptyValues tests handling of empty values
func TestRedisBackend_EmptyValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestRedisBackend_PipelineOperations tests batch operations
func TestRedisBackend_PipelineOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetMany", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
key := fmt.Sprintf("batch-key-%d", i)
value := []byte(fmt.Sprintf("batch-value-%d", i))
items[key] = value
}
err := backend.SetMany(ctx, items, 1*time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, retrieved)
}
})
t.Run("GetMany", func(t *testing.T) {
// Set test data
testData := GenerateTestData(5)
for key, value := range testData {
backend.Set(ctx, key, value, 1*time.Minute)
}
// Get all keys
keys := make([]string, 0, len(testData))
for key := range testData {
keys = append(keys, key)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, len(testData))
for key, expectedValue := range testData {
retrievedValue, exists := results[key]
assert.True(t, exists)
assert.Equal(t, expectedValue, retrievedValue)
}
})
t.Run("GetManyWithNonExistent", func(t *testing.T) {
keys := []string{"exists-1", "non-existent", "exists-2"}
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-1"), results["exists-1"])
assert.Equal(t, []byte("value-2"), results["exists-2"])
_, exists := results["non-existent"]
assert.False(t, exists)
})
}
// TestRedisBackend_NoPrefix tests operation without prefix
func TestRedisBackend_NoPrefix(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "" // No prefix
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "no-prefix-key"
value := []byte("no-prefix-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check key is stored without prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, key, keys[0])
}
+251
View File
@@ -0,0 +1,251 @@
package backends
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
}
// Release returns the writer to the pool for reuse
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
}
// WriteCommand writes a Redis command in RESP array format
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
func (w *RESPWriter) WriteCommand(args ...string) error {
// Write array header
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
return err
}
// Write each argument as bulk string
for _, arg := range args {
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
return err
}
}
return nil
}
// RESPReader reads RESP protocol messages
type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
}
// Release returns the reader to the pool for reuse
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
}
// ReadResponse reads a RESP response and returns the parsed value
func (r *RESPReader) ReadResponse() (interface{}, error) {
typeByte, err := r.r.ReadByte()
if err != nil {
return nil, err
}
switch typeByte {
case '+': // Simple string
return r.readSimpleString()
case '-': // Error
return nil, r.readError()
case ':': // Integer
return r.readInteger()
case '$': // Bulk string
return r.readBulkString()
case '*': // Array
return r.readArray()
default:
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
}
}
// readSimpleString reads a simple string (+OK\r\n)
func (r *RESPReader) readSimpleString() (string, error) {
line, err := r.readLine()
if err != nil {
return "", err
}
return line, nil
}
// readError reads an error message (-Error message\r\n)
func (r *RESPReader) readError() error {
line, err := r.readLine()
if err != nil {
return err
}
return errors.New(line)
}
// readInteger reads an integer (:1000\r\n)
func (r *RESPReader) readInteger() (int64, error) {
line, err := r.readLine()
if err != nil {
return 0, err
}
return strconv.ParseInt(line, 10, 64)
}
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
func (r *RESPReader) readBulkString() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
}
// -1 indicates nil bulk string
if length == -1 {
return nil, ErrNilResponse
}
// Read exactly 'length' bytes plus \r\n
buf := make([]byte, length+2)
if _, err := io.ReadFull(r.r, buf); err != nil {
return nil, err
}
// Verify \r\n terminator
if buf[length] != '\r' || buf[length+1] != '\n' {
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
}
return string(buf[:length]), nil
}
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
func (r *RESPReader) readArray() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
}
// -1 indicates nil array
if length == -1 {
return nil, ErrNilResponse
}
// Read each element
result := make([]interface{}, length)
for i := 0; i < length; i++ {
elem, err := r.ReadResponse()
if err != nil {
return nil, err
}
result[i] = elem
}
return result, nil
}
// readLine reads a line terminated by \r\n
func (r *RESPReader) readLine() (string, error) {
line, err := r.r.ReadString('\n')
if err != nil {
return "", err
}
// Remove \r\n
line = strings.TrimSuffix(line, "\r\n")
if !strings.HasSuffix(line+"\r\n", "\r\n") {
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
}
return line, nil
}
// RESPString extracts a string from RESP response
func RESPString(resp interface{}) (string, error) {
if resp == nil {
return "", ErrNilResponse
}
switch v := resp.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
default:
return "", fmt.Errorf("expected string, got %T", resp)
}
}
// RESPInt extracts an integer from RESP response
func RESPInt(resp interface{}) (int64, error) {
if resp == nil {
return 0, ErrNilResponse
}
switch v := resp.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
default:
return 0, fmt.Errorf("expected integer, got %T", resp)
}
}
+495
View File
@@ -0,0 +1,495 @@
package backends
import (
"bytes"
"errors"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRESPWriter_WriteCommand tests RESP command writing
func TestRESPWriter_WriteCommand(t *testing.T) {
tests := []struct {
name string
expected string
args []string
}{
{
name: "Simple command",
args: []string{"PING"},
expected: "*1\r\n$4\r\nPING\r\n",
},
{
name: "SET command",
args: []string{"SET", "key", "value"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
name: "SETEX command",
args: []string{"SETEX", "mykey", "60", "myvalue"},
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
},
{
name: "DEL with multiple keys",
args: []string{"DEL", "key1", "key2", "key3"},
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
},
{
name: "Command with empty string",
args: []string{"SET", "key", ""},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
name: "Command with special characters",
args: []string{"SET", "key", "val\r\nue"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
writer := NewRESPWriter(buf)
err := writer.WriteCommand(tt.args...)
require.NoError(t, err)
assert.Equal(t, tt.expected, buf.String())
})
}
}
// TestRESPReader_ReadSimpleString tests reading simple strings
func TestRESPReader_ReadSimpleString(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "OK response",
input: "+OK\r\n",
expected: "OK",
wantErr: false,
},
{
name: "PONG response",
input: "+PONG\r\n",
expected: "PONG",
wantErr: false,
},
{
name: "Empty string",
input: "+\r\n",
expected: "",
wantErr: false,
},
{
name: "String with spaces",
input: "+Hello World\r\n",
expected: "Hello World",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadError tests reading error messages
func TestRESPReader_ReadError(t *testing.T) {
tests := []struct {
name string
input string
expectedError string
}{
{
name: "ERR error",
input: "-ERR unknown command\r\n",
expectedError: "ERR unknown command",
},
{
name: "WRONGTYPE error",
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
},
{
name: "Simple error",
input: "-Error\r\n",
expectedError: "Error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.Equal(t, tt.expectedError, err.Error())
})
}
}
// TestRESPReader_ReadInteger tests reading integers
func TestRESPReader_ReadInteger(t *testing.T) {
tests := []struct {
name string
input string
expected int64
wantErr bool
}{
{
name: "Zero",
input: ":0\r\n",
expected: 0,
wantErr: false,
},
{
name: "Positive integer",
input: ":1000\r\n",
expected: 1000,
wantErr: false,
},
{
name: "Negative integer",
input: ":-1\r\n",
expected: -1,
wantErr: false,
},
{
name: "Large integer",
input: ":9223372036854775807\r\n",
expected: 9223372036854775807,
wantErr: false,
},
{
name: "Invalid integer",
input: ":abc\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadBulkString tests reading bulk strings
func TestRESPReader_ReadBulkString(t *testing.T) {
tests := []struct {
expected interface{}
name string
input string
wantErr bool
isNil bool
}{
{
name: "Simple bulk string",
input: "$6\r\nfoobar\r\n",
expected: "foobar",
wantErr: false,
},
{
name: "Empty bulk string",
input: "$0\r\n\r\n",
expected: "",
wantErr: false,
},
{
name: "Nil bulk string",
input: "$-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Binary safe bulk string",
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
expected: "\x00\x01\x02\x03\x04",
wantErr: false,
},
{
name: "Invalid length",
input: "$abc\r\ntest\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadArray tests reading arrays
func TestRESPReader_ReadArray(t *testing.T) {
tests := []struct {
name string
input string
expected []interface{}
wantErr bool
isNil bool
}{
{
name: "Empty array",
input: "*0\r\n",
expected: []interface{}{},
wantErr: false,
},
{
name: "Array of bulk strings",
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
expected: []interface{}{
"foo",
"bar",
},
wantErr: false,
},
{
name: "Array of integers",
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
},
wantErr: false,
},
{
name: "Mixed array",
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
int64(4),
"foobar",
},
wantErr: false,
},
{
name: "Nil array",
input: "*-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Nested arrays",
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
expected: []interface{}{
[]interface{}{"foo", "bar"},
[]interface{}{"baz"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_InvalidInput tests error handling for invalid input
func TestRESPReader_InvalidInput(t *testing.T) {
tests := []struct {
name string
input string
}{
{
name: "Unknown type byte",
input: "?invalid\r\n",
},
{
name: "Incomplete response",
input: "+OK",
},
{
name: "Missing CRLF in bulk string",
input: "$5\r\nhello",
},
{
name: "Truncated array",
input: "*3\r\n:1\r\n:2\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
})
}
}
// TestRESPReader_EOF tests handling of EOF
func TestRESPReader_EOF(t *testing.T) {
reader := NewRESPReader(strings.NewReader(""))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.True(t, errors.Is(err, io.EOF))
}
// TestRESPHelpers tests helper functions
func TestRESPHelpers(t *testing.T) {
t.Run("RESPString", func(t *testing.T) {
// Valid string
result, err := RESPString("hello")
require.NoError(t, err)
assert.Equal(t, "hello", result)
// Byte slice
result, err = RESPString([]byte("world"))
require.NoError(t, err)
assert.Equal(t, "world", result)
// Nil
_, err = RESPString(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPString(123)
require.Error(t, err)
})
t.Run("RESPInt", func(t *testing.T) {
// Valid int64
result, err := RESPInt(int64(42))
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Valid int
result, err = RESPInt(42)
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Nil
_, err = RESPInt(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPInt("string")
require.Error(t, err)
})
}
// TestRESPRoundTrip tests full round-trip encoding/decoding
func TestRESPRoundTrip(t *testing.T) {
tests := []struct {
expected interface{}
name string
response string
command []string
}{
{
name: "PING command",
command: []string{"PING"},
response: "+PONG\r\n",
expected: "PONG",
},
{
name: "GET command with result",
command: []string{"GET", "mykey"},
response: "$7\r\nmyvalue\r\n",
expected: "myvalue",
},
{
name: "GET command with nil",
command: []string{"GET", "nonexistent"},
response: "$-1\r\n",
expected: nil,
},
{
name: "DEL command",
command: []string{"DEL", "key1", "key2"},
response: ":2\r\n",
expected: int64(2),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Write command
writeBuf := &bytes.Buffer{}
writer := NewRESPWriter(writeBuf)
err := writer.WriteCommand(tt.command...)
require.NoError(t, err)
// Read response
reader := NewRESPReader(strings.NewReader(tt.response))
result, err := reader.ReadResponse()
if tt.expected == nil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
+183
View File
@@ -0,0 +1,183 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// SingleflightCache wraps a CacheBackend with singleflight deduplication
// to prevent thundering herd problems when multiple concurrent requests
// try to fetch the same uncached key.
type SingleflightCache struct {
backend CacheBackend
mu sync.Mutex
calls map[string]*singleflightCall
// Metrics
deduplicatedCalls atomic.Int64
totalCalls atomic.Int64
}
// singleflightCall represents an in-flight or completed fetch call
type singleflightCall struct {
wg sync.WaitGroup
val []byte
ttl time.Duration
err error
done bool
}
// NewSingleflightCache creates a new singleflight-wrapped cache backend
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
return &SingleflightCache{
backend: backend,
calls: make(map[string]*singleflightCall),
}
}
// Fetcher is a function type that fetches data when cache misses
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
// per key when there's a cache miss. Concurrent calls for the same key will
// wait for the first call to complete and share its result.
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
s.totalCalls.Add(1)
// Try cache first
value, _, exists, err := s.backend.Get(ctx, key)
if err != nil {
return nil, err
}
if exists {
return value, nil
}
// Cache miss - use singleflight
s.mu.Lock()
// Check if there's already an in-flight call for this key
if call, ok := s.calls[key]; ok {
s.mu.Unlock()
s.deduplicatedCalls.Add(1)
// Wait for the in-flight call to complete
call.wg.Wait()
// Check context cancellation
if ctx.Err() != nil {
return nil, ctx.Err()
}
return call.val, call.err
}
// Create new call
call := &singleflightCall{}
call.wg.Add(1)
s.calls[key] = call
s.mu.Unlock()
// Execute the fetcher
call.val, call.ttl, call.err = fetcher(ctx)
call.done = true
// If successful, store in cache
if call.err == nil && call.val != nil {
// Use a background context for cache storage to ensure it completes
// even if the original context is cancelled
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
cancel()
}
// Signal waiting goroutines
call.wg.Done()
// Clean up the call from the map after a short delay
// This allows late arrivals to still benefit from the result
go func() {
time.Sleep(100 * time.Millisecond)
s.mu.Lock()
if c, ok := s.calls[key]; ok && c == call {
delete(s.calls, key)
}
s.mu.Unlock()
}()
return call.val, call.err
}
// Get retrieves a value from the underlying cache backend
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
return s.backend.Get(ctx, key)
}
// Set stores a value in the underlying cache backend
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.backend.Set(ctx, key, value, ttl)
}
// Delete removes a key from the underlying cache backend
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
return s.backend.Delete(ctx, key)
}
// Exists checks if a key exists in the underlying cache backend
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
return s.backend.Exists(ctx, key)
}
// Clear removes all keys from the underlying cache backend
func (s *SingleflightCache) Clear(ctx context.Context) error {
return s.backend.Clear(ctx)
}
// GetStats returns cache statistics including singleflight metrics
func (s *SingleflightCache) GetStats() map[string]interface{} {
stats := s.backend.GetStats()
// Add singleflight-specific stats
totalCalls := s.totalCalls.Load()
deduped := s.deduplicatedCalls.Load()
stats["singleflight_total_calls"] = totalCalls
stats["singleflight_deduplicated"] = deduped
if totalCalls > 0 {
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
} else {
stats["singleflight_dedup_rate"] = float64(0)
}
s.mu.Lock()
stats["singleflight_inflight"] = len(s.calls)
s.mu.Unlock()
return stats
}
// Close shuts down the cache backend
func (s *SingleflightCache) Close() error {
return s.backend.Close()
}
// Ping checks if the backend is healthy
func (s *SingleflightCache) Ping(ctx context.Context) error {
return s.backend.Ping(ctx)
}
// GetBackend returns the underlying cache backend
func (s *SingleflightCache) GetBackend() CacheBackend {
return s.backend
}
// ResetStats resets the singleflight statistics
func (s *SingleflightCache) ResetStats() {
s.totalCalls.Store(0)
s.deduplicatedCalls.Store(0)
}
// Ensure SingleflightCache implements CacheBackend
var _ CacheBackend = (*SingleflightCache)(nil)
+510
View File
@@ -0,0 +1,510 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("CacheHit", func(t *testing.T) {
key := "existing-key"
value := []byte("existing-value")
// Pre-populate cache
err := cache.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return []byte("fetched-value"), time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, value, result)
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
})
t.Run("CacheMiss", func(t *testing.T) {
key := "missing-key"
expectedValue := []byte("fetched-value")
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return expectedValue, time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, expectedValue, result)
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
// Verify value was stored in cache
cached, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, cached)
})
t.Run("FetcherError", func(t *testing.T) {
key := "error-key"
expectedErr := errors.New("fetch failed")
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return nil, 0, expectedErr
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, expectedErr, err)
assert.Nil(t, result)
// Verify nothing was stored in cache
_, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
}
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
func TestSingleflightCache_Deduplication(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "dedup-key"
expectedValue := []byte("dedup-value")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
// Simulate slow fetch
time.Sleep(100 * time.Millisecond)
return expectedValue, time.Minute, nil
}
// Launch multiple concurrent requests
concurrency := 10
var wg sync.WaitGroup
results := make([][]byte, concurrency)
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same result
for i := 0; i < concurrency; i++ {
assert.NoError(t, errs[i])
assert.Equal(t, expectedValue, results[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
// Verify deduplication stats
stats := cache.GetStats()
deduped := stats["singleflight_deduplicated"].(int64)
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
}
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
func TestSingleflightCache_DifferentKeys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
var fetchCount atomic.Int32
fetchStarted := make(chan struct{}, 3)
fetchComplete := make(chan struct{})
fetcher := func(key string) Fetcher {
return func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
fetchStarted <- struct{}{}
<-fetchComplete // Wait for signal
return []byte("value-" + key), time.Minute, nil
}
}
// Launch concurrent requests for different keys
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
key := fmt.Sprintf("key-%d", idx)
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
}(i)
}
// Wait for all fetches to start
for i := 0; i < 3; i++ {
<-fetchStarted
}
// All 3 fetches should be running in parallel
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
// Release all fetches
close(fetchComplete)
wg.Wait()
}
// TestSingleflightCache_ContextCancellation tests context cancellation
func TestSingleflightCache_ContextCancellation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
key := "cancel-key"
fetchStarted := make(chan struct{})
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
close(fetchStarted)
// Simulate slow fetch
time.Sleep(500 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
// Start first request with long timeout
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}()
// Wait for fetch to start
<-fetchStarted
// Start second request with short timeout
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
wg.Wait()
}
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "error-prop-key"
expectedErr := errors.New("intentional error")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
time.Sleep(50 * time.Millisecond)
return nil, 0, expectedErr
}
// Launch multiple concurrent requests
concurrency := 5
var wg sync.WaitGroup
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same error
for i := 0; i < concurrency; i++ {
assert.Error(t, errs[i])
assert.Equal(t, expectedErr, errs[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load())
}
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("Set", func(t *testing.T) {
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
require.NoError(t, err)
val, _, exists, err := cache.Get(ctx, "set-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("set-value"), val)
})
t.Run("Get", func(t *testing.T) {
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
require.NoError(t, err)
val, ttl, exists, err := cache.Get(ctx, "get-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("get-value"), val)
assert.Greater(t, ttl, time.Duration(0))
})
t.Run("Delete", func(t *testing.T) {
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
require.NoError(t, err)
deleted, err := cache.Delete(ctx, "delete-key")
require.NoError(t, err)
assert.True(t, deleted)
exists, err := cache.Exists(ctx, "delete-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
exists, err := cache.Exists(ctx, "nonexistent")
require.NoError(t, err)
assert.False(t, exists)
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
require.NoError(t, err)
exists, err = cache.Exists(ctx, "exists-key")
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
require.NoError(t, err)
err = cache.Clear(ctx)
require.NoError(t, err)
exists, err := cache.Exists(ctx, "clear-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Ping", func(t *testing.T) {
err := cache.Ping(ctx)
require.NoError(t, err)
})
}
// TestSingleflightCache_Stats tests statistics tracking
func TestSingleflightCache_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
// Make some calls
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(50 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
}()
}
wg.Wait()
stats := cache.GetStats()
// Check singleflight stats exist
assert.Contains(t, stats, "singleflight_total_calls")
assert.Contains(t, stats, "singleflight_deduplicated")
assert.Contains(t, stats, "singleflight_dedup_rate")
assert.Contains(t, stats, "singleflight_inflight")
// Verify values
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
// Also check underlying backend stats are included
assert.Contains(t, stats, "hits")
assert.Contains(t, stats, "misses")
}
// TestSingleflightCache_ResetStats tests stats reset
func TestSingleflightCache_ResetStats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("value"), time.Minute, nil
}
// Make some calls
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
stats := cache.GetStats()
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
// Reset stats
cache.ResetStats()
stats = cache.GetStats()
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
}
// TestSingleflightCache_GetBackend tests GetBackend method
func TestSingleflightCache_GetBackend(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
assert.Equal(t, backend, cache.GetBackend())
}
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key-%d", i%100)
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}
}
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(time.Millisecond) // Simulate slow fetch
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
_, _ = cache.GetOrFetch(ctx, key, fetcher)
i++
}
})
}
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// All goroutines hit the same key
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
}
})
}

Some files were not shown because too many files have changed in this diff Show More