Compare commits

..

20 Commits

Author SHA1 Message Date
lukaszraczylo 6efb78b7a8 Smarter approach to the cookies (#103)
* Smarter approach to the cookies

  - Single maxCookieSize = 1400 constant with clear documentation
  - Combined cookie storage for ~40-45% size reduction
  - Backward compatible migration from legacy cookies

* Tuneup the code.
2025-12-12 18:35:06 +00:00
lukaszraczylo d0b920c4f0 multiple realms fix (#102)
* Allow to use multiple realms

This change is a ressurection of PR #88 which can't be merged due to significant refactor of the codebase.

* Fix the autocleanup routine to handle multiple realms correctly, update tests.

* Metadata rediscovery when provider is unavailable for any reason during the start.

This one prevents the permanent 503 from the plugin when OIDC provider was for some reason unavailable during the start.
2025-12-10 13:07:22 +00:00
lukaszraczylo c474bbafd6 Cleanup [dec2025] (#101)
* Cleanup excessive comments.

* Remove leftovers hanging around from previous refactor

* Improve test coverage
2025-12-09 01:38:02 +00:00
lukaszraczylo 9126c74723 December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)
* Allow internal IPs for OIDC configuration via extra flag.

Addresses issue #97

* Allow for internal IPs in OIDC configuration.

Addresses issue #97.

* feat: Add allowPrivateIPAddresses config option for internal networks

Adds a new configuration option `allowPrivateIPAddresses` that allows
OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x,
192.168.x.x). This is useful for internal deployments where Keycloak or
other OIDC providers run on private networks without DNS resolution.

Security considerations:
- Loopback addresses (127.0.0.1, localhost, ::1) remain blocked
- Link-local addresses (169.254.x.x) remain blocked
- Default is false (secure by default)

Fixes #97

* feat: Support non-email user identifiers for Azure AD

Add userIdentifierClaim configuration option to support Azure AD users
without email addresses. This allows using alternative JWT claims like
"sub", "oid", "upn", or "preferred_username" for user identification.

- Default behavior uses "email" claim (backward compatible)
- Falls back to "sub" claim if configured claim is missing
- allowedUsers matches against the configured claim value
- allowedUserDomains only applies when using email-based identification

Fixes #95

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Headers too big and 431 responses

Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services.

  - When minimalHeaders: false (default): All headers are forwarded as before
    - X-Forwarded-User (always set)
    - X-Auth-Request-Redirect
    - X-Auth-Request-User
    - X-Auth-Request-Token (the large ID token)
    - X-User-Groups, X-User-Roles (if configured)
  - When minimalHeaders: true: Reduces header overhead
    - X-Forwarded-User (always set)
    - X-User-Groups, X-User-Roles (still forwarded if configured)
    - Custom templated headers (still processed)
    - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect

Fixes issues #64 and #86
2025-12-08 14:21:17 +00:00
lukaszraczylo a750c4f5b9 Size computation for allocation may overflow (#99)
* Size computation for allocation may overflow

Performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
2025-12-08 11:22:28 +00:00
lukaszraczylo 56051779ee Hotfix: goreleaser archive format. 2025-12-08 02:39:40 +00:00
lukaszraczylo 3f126d50f3 Force the v in the release tags and name. 2025-12-08 02:34:10 +00:00
lukaszraczylo 91f0fc9ab8 Switch to go releaser 2025-12-08 02:32:46 +00:00
lukaszraczylo 66b9ed0861 Reauthentication + redis fix
When introspection explicitly returns that a token is inactive/revoked/expired, the plugin now properly triggers re-authentication or refresh instead of falling back to ID token validation. This fixes the functional issue where users
weren't being redirected to re-authenticate.
Redis change ensures that when the caller's context is cancelled (e.g., the 200ms timeout in UniversalCache.Get()), the operation aborts quickly instead of continuing with retries.
2025-12-01 13:47:28 +00:00
lukaszraczylo e64fc7f730 Add redis support for distributed caching (#83)
* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-11-30 02:18:46 +00:00
lukaszraczylo 5fcbd54955 Add sharded cache and prevention of CPU spikes / locks (#96)
* Add sharded cache and prevention of CPU spikes / locks

* Add dynamic client registration with oidc provider

* Fix race condition introduced during the sharded cache implementation.

* Add page for traefikoidc.
2025-11-30 01:41:12 +00:00
lukaszraczylo e70cd1907c Create CNAME 2025-11-30 01:28:07 +00:00
lukaszraczylo e45b06c86d Fix markdown issues. 2025-10-17 14:40:50 +01:00
lukaszraczylo ae59a5e88a 0.7.10 (#80)
* Add ability to disable replay protection. - This is useful for runs with multiple traefik replicas to avoid false positives and tokens re-creation.
* Enhance the CI/CD pipelines
* Increase test coverage.
* Update vendored dependencies.
* Update behaviour on forceHTTPS as per issue #82
2025-10-16 10:56:28 +01:00
lukaszraczylo 79e9b164f9 release 0.7.9 (#78)
* Speed improvements.

After introduction of introspection the plugin became significantly slower.
This commit introduces several optimizations to bring the speed back up.

* Add relevant documentation and tests.
2025-10-13 10:43:35 +01:00
lukaszraczylo 93888e56d1 fixup! Multiple issues addressed (#76) 2025-10-09 00:56:53 +01:00
lukaszraczylo eff9bd7bd2 Multiple issues addressed (#76)
- Issue #74
- Issue #14
2025-10-09 00:44:03 +01:00
lukaszraczylo bde1db1c3b traefik plugin 0.7.7 (#73)
* Automatic discovery of the scopes.

Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication:  FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication:  SUCCEEDS

* Resolves issue #74 by enabling user to specify expected audience in the configuration.

* Fix flaky tests.
2025-10-08 11:44:00 +01:00
lukaszraczylo 79d34ea4c9 Fix recursion in token resilience logic (#72) 2025-10-07 10:34:15 +01:00
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00
501 changed files with 153494 additions and 26711 deletions
+38
View File
@@ -0,0 +1,38 @@
# Code Owners for traefik-oidc
# These owners will be automatically requested for review when someone opens a PR
# Default owner for everything in the repo
* @lukaszraczylo
# Core authentication and middleware
/middleware/ @lukaszraczylo
/auth/ @lukaszraczylo
/handlers/ @lukaszraczylo
# OIDC providers
/internal/providers/ @lukaszraczylo
# Session management and security
/session/ @lukaszraczylo
/internal/security/ @lukaszraczylo
/security/ @lukaszraczylo
# Token management
/internal/token/ @lukaszraczylo
# Configuration
/config/ @lukaszraczylo
/.traefik.yml @lukaszraczylo
# GitHub Actions and CI/CD
/.github/ @lukaszraczylo
/.github/workflows/ @lukaszraczylo
/.golangci.yml @lukaszraczylo
# Documentation
/docs/ @lukaszraczylo
README.md @lukaszraczylo
# Dependencies
go.mod @lukaszraczylo
go.sum @lukaszraczylo
+123
View File
@@ -0,0 +1,123 @@
## Description
<!-- Provide a brief description of the changes in this PR -->
## Type of Change
<!-- Mark the relevant option with an "x" -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update
- [ ] Performance improvement
- [ ] Code refactoring
- [ ] Security fix
- [ ] Provider-specific fix/enhancement
## Related Issues
<!-- Link to related issues using #issue_number -->
Fixes #
Related to #
## Changes Made
<!-- List the main changes made in this PR -->
-
-
-
## Provider Impact
<!-- If this affects specific OIDC providers, list them here -->
- [ ] Google
- [ ] Azure AD
- [ ] Auth0
- [ ] Okta
- [ ] Keycloak
- [ ] AWS Cognito
- [ ] GitLab
- [ ] GitHub
- [ ] Generic OIDC
- [ ] All providers
## Testing Performed
<!-- Describe the tests you ran to verify your changes -->
- [ ] Unit tests pass locally
- [ ] Integration tests pass locally
- [ ] Race detector shows no issues
- [ ] Memory leak tests pass
- [ ] Manual testing performed
### Test Configuration
<!-- Provide details about your test configuration if applicable -->
**Provider tested:**
**Go version:**
**Traefik version:**
## Security Considerations
<!-- Describe any security implications of these changes -->
- [ ] This PR does not introduce security vulnerabilities
- [ ] Security scanning has been performed
- [ ] Credentials/secrets are properly handled
- [ ] Input validation is implemented
## Performance Impact
<!-- Describe any performance implications -->
- [ ] No performance impact expected
- [ ] Performance improved (describe how)
- [ ] Performance may be affected (describe why and mitigation)
## Breaking Changes
<!-- If this is a breaking change, describe the impact and migration path -->
**Breaking changes:**
**Migration guide:**
## Checklist
<!-- Ensure all items are checked before requesting review -->
- [ ] My code follows the project's code style
- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published
## Additional Context
<!-- Add any other context, screenshots, or information about the PR here -->
## Screenshots (if applicable)
<!-- Add screenshots to help explain your changes -->
---
**For Reviewers:**
Please verify:
- [ ] Code quality and style
- [ ] Test coverage is adequate
- [ ] Security implications reviewed
- [ ] Documentation is updated
- [ ] No performance regressions
+52
View File
@@ -0,0 +1,52 @@
version: 2
updates:
# Maintain dependencies for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 5
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "github-actions"
reviewers:
- "lukaszraczylo"
# Maintain Go module dependencies
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 10
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "go"
reviewers:
- "lukaszraczylo"
# Group patch updates together
groups:
patch-updates:
patterns:
- "*"
update-types:
- "patch"
minor-updates:
patterns:
- "*"
update-types:
- "minor"
# Ignore certain dependencies if needed
ignore:
# Example: ignore specific versions
# - dependency-name: "github.com/example/package"
# versions: ["1.x", "2.x"]
+9
View File
@@ -0,0 +1,9 @@
# Ensure consistent line endings
* text=auto eol=lf
# GitHub Actions files should use LF
*.yml text eol=lf
*.yaml text eol=lf
# Shell scripts should use LF
*.sh text eol=lf
+225
View File
@@ -0,0 +1,225 @@
# GitHub Actions Workflows
This directory contains CI/CD workflows for the Traefik OIDC middleware.
## Workflows
### PR Validation (`pr-validation.yml`)
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
**Triggered on:**
- Pull requests to `main` branch
- Pushes to `main` branch
**Parallel Jobs (20+ concurrent checks):**
#### Code Quality
- **Quick Checks** - Format, go vet, go mod verify
- **golangci-lint** - Comprehensive linting
- **Staticcheck** - Static analysis
#### Security
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database check
- **CodeQL** - GitHub's code analysis
#### Testing
- **Race Detector** - Concurrent access bug detection
- **Coverage** - Test coverage with 75% threshold
- **Memory Leaks** - Goroutine and memory leak detection
- **Integration Tests** - Full integration test suite
- **Regression Tests** - Prevent previously fixed bugs
- **Security Edge Cases** - Security-specific scenarios
- **Session Tests** - Session management validation
- **Token Tests** - Token validation scenarios
- **CSRF Tests** - CSRF protection validation
#### Provider Testing (Matrix)
Tests run in parallel for each OIDC provider:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Compatibility
- **Benchmarks** - Performance regression detection
- **Build Matrix** - linux/darwin × amd64/arm64
- **Go Versions** - Go 1.23 and 1.24 compatibility
#### Final Validation
- **All Checks Passed** - Ensures all jobs succeeded
## Workflow Features
### 🚀 Parallel Execution
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
### 📊 Coverage Reporting
- Automatic PR comments with coverage statistics
- Per-package coverage breakdown
- 75% coverage threshold enforcement
### 🔒 Security First
- Multiple security scanners (gosec, govulncheck, CodeQL)
- SARIF report uploads for GitHub Security tab
- Security edge case testing
### 🎯 Comprehensive Testing
- Race condition detection
- Memory leak detection
- Provider-specific testing
- Integration and regression tests
### 📈 Performance Tracking
- Benchmark results stored as artifacts
- Performance regression detection
### ✅ Quality Gates
All checks must pass before PR can be merged:
- Code formatting and style
- Security vulnerabilities
- Test coverage threshold
- Race conditions
- Memory leaks
- Build success on all platforms
## Local Development
### Run checks locally before pushing:
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Run tests with race detector
go test -race -timeout=15m -count=1 ./...
# Check coverage
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# Run specific test suites
go test -v -run='.*Leak.*' ./... # Memory leak tests
go test -v -run='.*Integration.*' ./... # Integration tests
go test -v -run='.*Regression.*' ./... # Regression tests
# Run benchmarks
go test -bench=. -benchmem ./...
# Security scan
gosec ./...
govulncheck ./...
```
### Required Tools
Install these tools for local development:
```bash
# golangci-lint
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck
go install golang.org/x/vuln/cmd/govulncheck@latest
```
## Troubleshooting
### Workflow Fails
1. **Check job status** - Click on failed job for details
2. **Review logs** - Expand failed steps to see error messages
3. **Run locally** - Reproduce issue with local commands above
4. **Check coverage** - Ensure test coverage meets 75% threshold
### Coverage Below Threshold
Add tests to increase coverage:
```bash
# See which lines aren't covered
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
### Race Condition Detected
Run with race detector locally:
```bash
go test -race -v ./...
```
### Provider Test Failure
Test specific provider:
```bash
go test -v -run='.*Azure.*' ./internal/providers/...
```
## Performance Optimization
The workflow is optimized for speed:
- **Parallel execution** - All independent jobs run simultaneously
- **Go caching** - Dependencies cached between runs
- **Strategic ordering** - Quick checks run first for fast feedback
- **Fail-fast disabled** - Continue running all tests even if some fail
## Workflow Monitoring
### GitHub Actions Dashboard
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
### Status Badges
Add to README.md:
```markdown
![PR Validation](https://github.com/{owner}/{repo}/actions/workflows/pr-validation.yml/badge.svg)
```
### Notifications
Configure in repository settings:
- Settings → Notifications
- Choose email or Slack notifications for workflow failures
## Maintenance
### Update Go Version
Edit in workflow file:
```yaml
go-version: '1.24' # Update this
```
### Adjust Coverage Threshold
Edit in workflow file:
```yaml
THRESHOLD=75 # Adjust this value
```
### Add New Provider
Add to provider matrix:
```yaml
matrix:
provider:
- new_provider # Add here
```
## Additional Resources
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
- [golangci-lint Configuration](../.golangci.yml)
- [Dependabot Configuration](../dependabot.yml)
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
+23
View File
@@ -0,0 +1,23 @@
name: Pull Request
on:
pull_request:
branches:
- main
push:
branches:
- "**"
- "!main"
permissions:
contents: read
pull-requests: write
security-events: write
jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+21
View File
@@ -0,0 +1,21 @@
name: Release
on:
push:
branches:
- main
paths:
- "**.go"
- "go.mod"
- "go.sum"
workflow_dispatch:
permissions:
contents: write
jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24.11"
secrets: inherit
+2 -1
View File
@@ -1,2 +1,3 @@
docker/
.claude/
.claude/*.out
*.test
+192
View File
@@ -0,0 +1,192 @@
version: "2"
run:
go: "1.24"
modules-download-mode: readonly
tests: true
linters:
enable:
- bodyclose
- dupl
- goconst
- gocritic
- gocyclo
- goprintffuncname
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- lll
- mnd
- testpackage
- wsl
settings:
dupl:
threshold: 200 # Allow intentional duplication in provider patterns and token management
errcheck:
check-type-assertions: true
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
exclude-functions:
- (io.Closer).Close
- (*database/sql.Rows).Close
- (*database/sql.Stmt).Close
- (io.Writer).Write
- (*net/http.ResponseWriter).Write
- fmt.Fprintf
- fmt.Fprint
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
excludes:
- G104
- G404
severity: medium
confidence: medium
govet:
disable:
- fieldalignment
- shadow
enable-all: true
misspell:
locale: US
ignore-rules:
- traefik
- oidc
- keycloak
nolintlint:
require-explanation: true
require-specific: true
allow-unused: false
prealloc:
simple: true
range-loops: true
for-loops: false
revive:
rules:
- name: blank-imports
- name: context-as-argument
- name: context-keys-type
- name: dot-imports
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
- name: errorf
- name: empty-block
- name: superfluous-else
- name: unused-parameter
- name: unreachable-code
- name: redefines-builtin-id
unparam:
check-exported: false
staticcheck:
checks:
- all
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1012 # Use fmt.Fprintf - style preference
- -ST1003 # Package name format - allowed for test packages
exclusions:
generated: lax
rules:
- linters:
- bodyclose
- dupl
- errcheck
- goconst
- gocyclo
- gosec
- noctx
- prealloc
- unparam
path: _test\.go
- linters:
- dupl
- gocyclo
path: test.*\.go
- linters:
- gocritic
- unused
path: mocks.*\.go
- linters:
- gosec
text: 'G404:'
- linters:
- all
path: vendor/
- linters:
- goconst
path: (.+)_test\.go
- linters:
- dupl
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
- linters:
- dupl
path: session\.go
- linters:
- dupl
path: session_chunk_manager\.go
text: "(extractJWTExpiration|extractJWTIssuedAt)"
paths:
- third_party$
- builtin$
- examples$
issues:
max-issues-per-linter: 0
max-same-issues: 0
uniq-by-line: true
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$
+49
View File
@@ -0,0 +1,49 @@
version: 2
# Traefik plugins are source-only - no binary builds
# Traefik loads plugins via Yaegi interpreter at runtime
builds:
- skip: true
# Create source archive for GitHub releases
archives:
- formats: [tar.gz]
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
files:
- "*.go"
- "**/*.go"
- go.mod
- go.sum
- .traefik.yml
- LICENSE*
- README*
# Exclude test files and vendor from release archive
- "!**/*_test.go"
- "!vendor/**"
- "!docker/**"
- "!integration/**"
- "!regression/**"
- "!examples/**"
- "!docs/**"
checksum:
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
algorithm: sha256
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
- "^Merge"
- "^WIP"
- "^chore:"
release:
github:
owner: lukaszraczylo
name: traefikoidc
name_template: "v{{ .Version }}"
draft: false
prerelease: auto
+1218 -65
View File
File diff suppressed because it is too large Load Diff
+1213 -140
View File
File diff suppressed because it is too large Load Diff
-308
View File
@@ -1,308 +0,0 @@
# Test Execution Guide
This guide explains how to run tests efficiently with the new test categorization and optimization system.
## Quick Start
### Fast Development Testing (Default - Target: < 30 seconds)
```bash
# Run quick smoke tests only
go test ./...
# Or explicitly run in short mode
go test ./... -short
```
### Extended Testing (Target: 2-5 minutes)
```bash
# Enable extended tests with more iterations and concurrency
RUN_EXTENDED_TESTS=1 go test ./...
# Or use the flag equivalent (if using test runner that supports it)
go test ./... -extended
```
### Long-Running Performance Tests (Target: 5-15 minutes)
```bash
# Enable comprehensive performance and stress tests
RUN_LONG_TESTS=1 go test ./...
```
### Full Stress Testing (Target: 10-30 minutes)
```bash
# Enable all stress tests with maximum parameters
RUN_STRESS_TESTS=1 go test ./...
```
## Test Categories
### 1. Quick Tests (Default)
- **Purpose**: Fast feedback during development
- **Duration**: < 30 seconds total
- **Features**:
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Minimal concurrency
- Essential memory leak checks
**Configuration**:
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Cache Size: 50
- Timeout: 10 seconds
### 2. Extended Tests
- **Purpose**: Comprehensive testing before commits
- **Duration**: 2-5 minutes
- **Features**:
- Increased test coverage
- More iterations (5-10)
- Medium concurrency tests
- Enhanced memory leak detection
**Configuration**:
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Cache Size: 200
- Timeout: 30 seconds
### 3. Long Tests
- **Purpose**: Performance validation and stress testing
- **Duration**: 5-15 minutes
- **Features**:
- High iteration counts (50-100)
- High concurrency scenarios
- Large data sets
- Comprehensive memory testing
**Configuration**:
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Cache Size: 1000
- Timeout: 60 seconds
### 4. Stress Tests
- **Purpose**: Maximum load testing and edge case validation
- **Duration**: 10-30 minutes
- **Features**:
- Extreme iteration counts (100-500)
- Maximum concurrency (100+)
- Large memory allocations
- Edge case combinations
**Configuration**:
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Cache Size: 2000
- Timeout: 120 seconds
## Environment Variables
### Test Execution Control
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1 # Enable extended tests
export RUN_LONG_TESTS=1 # Enable long-running tests
export RUN_STRESS_TESTS=1 # Enable stress tests
# Disable specific features
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
```
### Parameter Customization
```bash
# Customize concurrency limits
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
# Customize iteration limits
export TEST_MAX_ITERATIONS=50 # Override max test iterations
# Customize memory thresholds
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
```
## Test-Specific Behavior
### Memory Leak Tests
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
- **Long Mode**: 50-100 iterations, large data sets, performance focus
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
### Concurrency Tests
- **Quick Mode**: 2-5 concurrent operations, basic race detection
- **Extended Mode**: 10-20 concurrent operations, moderate stress
- **Long Mode**: 20-50 concurrent operations, high contention
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
### Cache Tests
- **Quick Mode**: Small caches (50 items), basic operations
- **Extended Mode**: Medium caches (200 items), varied operations
- **Long Mode**: Large caches (1000 items), performance testing
- **Stress Mode**: Very large caches (2000+ items), stress testing
## Integration with CI/CD
### GitHub Actions Example
```yaml
# Quick tests for every push/PR
- name: Quick Tests
run: go test ./... -short
# Extended tests for main branch
- name: Extended Tests
if: github.ref == 'refs/heads/main'
run: RUN_EXTENDED_TESTS=1 go test ./...
# Nightly comprehensive testing
- name: Nightly Stress Tests
if: github.event_name == 'schedule'
run: RUN_STRESS_TESTS=1 go test ./...
```
### Local Development Workflow
```bash
# During active development
go test ./... -short
# Before committing
RUN_EXTENDED_TESTS=1 go test ./...
# Before major releases
RUN_LONG_TESTS=1 go test ./...
# Performance validation
RUN_STRESS_TESTS=1 go test ./...
```
## Performance Optimization Features
### Dynamic Test Scaling
The test system automatically adjusts parameters based on:
- Test mode (quick/extended/long/stress)
- Available resources
- Environment variables
- Previous test performance
### Memory Management
- **Garbage Collection**: Forced GC between test iterations
- **Memory Monitoring**: Real-time memory growth tracking
- **Leak Detection**: Goroutine and memory leak prevention
- **Resource Cleanup**: Automatic cleanup of test resources
### Timeout Management
- **Adaptive Timeouts**: Timeouts scale with test complexity
- **Graceful Degradation**: Tests adapt to slower environments
- **Early Termination**: Failed tests terminate quickly
## Troubleshooting
### Tests Taking Too Long
```bash
# Check if running in extended mode accidentally
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
# Force quick mode
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
go test ./... -short
```
### Memory Issues
```bash
# Reduce memory limits for constrained environments
export TEST_MEMORY_THRESHOLD_MB=5.0
export TEST_MAX_CONCURRENCY=2
go test ./...
```
### Concurrency Issues
```bash
# Reduce concurrency for slower systems
export TEST_MAX_CONCURRENCY=5
export TEST_MAX_ITERATIONS=10
go test ./...
```
### Skip Specific Test Types
```bash
# Skip memory leak detection if problematic
export DISABLE_LEAK_DETECTION=1
go test ./...
```
## Benchmarking
### Running Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
### Benchmark Categories
- **Basic Operations**: Set/Get performance
- **Concurrency**: Multi-threaded performance
- **Memory**: Allocation and cleanup performance
- **Cache**: Eviction and cleanup performance
## Best Practices
### For Developers
1. Always run quick tests during development (`go test ./... -short`)
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
3. Use appropriate test categories for your use case
4. Monitor test execution time and adjust if needed
### For CI/CD
1. Use quick tests for fast feedback on PRs
2. Use extended tests for main branch validation
3. Use long tests for release validation
4. Use stress tests for nightly/weekly validation
### For Performance Testing
1. Use consistent environment variables
2. Run tests multiple times for statistical significance
3. Monitor both execution time and resource usage
4. Use profiling tools for detailed analysis
## Examples
### Daily Development
```bash
# Fast tests while coding
go test ./... -short
# Before git commit
RUN_EXTENDED_TESTS=1 go test ./...
```
### Release Testing
```bash
# Comprehensive validation
RUN_LONG_TESTS=1 go test ./...
# Stress testing
RUN_STRESS_TESTS=1 go test ./...
```
### Custom Configuration
```bash
# Custom limits for specific environment
export TEST_MAX_CONCURRENCY=8
export TEST_MAX_ITERATIONS=25
export TEST_MEMORY_THRESHOLD_MB=15.0
RUN_EXTENDED_TESTS=1 go test ./...
```
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
+1518
View File
File diff suppressed because it is too large Load Diff
-360
View File
@@ -1,360 +0,0 @@
// Package auth provides authentication-related functionality for the OIDC middleware.
package auth
import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/google/uuid"
)
// AuthHandler provides core authentication functionality for OIDC flows
type AuthHandler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// NewAuthHandler creates a new AuthHandler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
return &AuthHandler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
}
}
// InitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
session SessionData, redirectURL string,
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return
}
session.IncrementRedirectCount()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
h.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if h.enablePKCE {
codeVerifier, err = generateCodeVerifier()
if err != nil {
h.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
codeChallenge, err = deriveCodeChallenge()
if err != nil {
h.logger.Errorf("Failed to generate code challenge: %v", err)
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
return
}
h.logger.Debugf("PKCE enabled, generated code challenge")
}
session.SetAuthenticated(false)
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if h.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// BuildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", h.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
if h.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
scopes := make([]string, len(h.scopes))
copy(scopes, h.scopes)
if h.isGoogleProv() {
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
params.Set("prompt", "consent")
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else if h.isAzureProv() {
params.Set("response_mode", "query")
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
} else {
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
return h.buildURLWithParams(h.authURL, params)
}
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
// It handles both relative and absolute URLs, validates URL security,
// and properly encodes query parameters.
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
if baseURL != "" {
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
if err := h.validateURL(baseURL); err != nil {
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
return ""
}
}
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
issuerURLParsed, err := url.Parse(h.issuerURL)
if err != nil {
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
return ""
}
baseURLParsed, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
return ""
}
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
if err := h.validateURL(resolvedURL.String()); err != nil {
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
return ""
}
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
u, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
return ""
}
if err := h.validateParsedURL(u); err != nil {
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
return ""
}
u.RawQuery = params.Encode()
return u.String()
}
// validateURL performs security validation on URLs to prevent SSRF attacks.
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
func (h *AuthHandler) validateURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("empty URL")
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
return h.validateParsedURL(u)
}
// validateParsedURL validates a parsed URL structure for security.
// It checks schemes, hosts, and paths to prevent malicious URLs.
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
allowedSchemes := map[string]bool{
"https": true,
"http": true,
}
if !allowedSchemes[u.Scheme] {
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
}
if u.Scheme == "http" {
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if err := h.validateHost(u.Host); err != nil {
return fmt.Errorf("invalid host: %w", err)
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// validateHost validates a hostname for security and reachability.
// It prevents access to private networks and localhost addresses.
func (h *AuthHandler) validateHost(host string) error {
if host == "" {
return fmt.Errorf("empty host")
}
// Strip port if present
if strings.Contains(host, ":") {
var err error
host, _, err = net.SplitHostPort(host)
if err != nil {
return fmt.Errorf("invalid host:port format: %w", err)
}
}
// Check for localhost variations
localhostVariations := []string{
"localhost", "127.0.0.1", "::1", "0.0.0.0",
}
for _, localhost := range localhostVariations {
if strings.EqualFold(host, localhost) {
return fmt.Errorf("localhost access not allowed: %s", host)
}
}
// Try to parse as IP address
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return fmt.Errorf("loopback IP not allowed: %s", host)
}
if ip.IsPrivate() {
return fmt.Errorf("private IP not allowed: %s", host)
}
if ip.IsLinkLocalUnicast() {
return fmt.Errorf("link-local IP not allowed: %s", host)
}
if ip.IsMulticast() {
return fmt.Errorf("multicast IP not allowed: %s", host)
}
}
return nil
}
// SessionData interface for dependency injection
type SessionData interface {
GetRedirectCount() int
ResetRedirectCount()
IncrementRedirectCount()
SetAuthenticated(bool)
SetEmail(string)
SetAccessToken(string)
SetRefreshToken(string)
SetIDToken(string)
SetNonce(string)
SetCodeVerifier(string)
SetCSRF(string)
SetIncomingPath(string)
MarkDirty()
Save(req *http.Request, rw http.ResponseWriter) error
}
+342
View File
@@ -0,0 +1,342 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
)
// validateRedirectCount checks if redirect limit is exceeded and handles the error
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
t.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
t.sendErrorResponse(rw, req, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return fmt.Errorf("redirect limit exceeded")
}
session.IncrementRedirectCount()
return nil
}
// generatePKCEParameters generates PKCE code verifier and challenge if PKCE is enabled
func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
if !t.enablePKCE {
return "", "", nil
}
codeVerifier, err := generateCodeVerifier()
if err != nil {
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := deriveCodeChallenge(codeVerifier)
t.logger.Debugf("PKCE enabled, generated code challenge")
return codeVerifier, codeChallenge, nil
}
// prepareSessionForAuthentication clears existing session data and sets new authentication state
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
// Set new authentication state
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if t.enablePKCE && codeVerifier != "" {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(incomingPath)
t.logger.Debugf("Storing incoming path: %s", incomingPath)
}
// defaultInitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request initiating authentication.
// - session: The session data to prepare for authentication.
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
// Check and handle redirect limits
if err := t.validateRedirectCount(session, rw, req); err != nil {
return
}
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE parameters if enabled
codeVerifier, codeChallenge, err := t.generatePKCEParameters()
if err != nil {
t.logger.Errorf("Failed to generate PKCE parameters: %v", err)
http.Error(rw, "Failed to generate PKCE parameters", http.StatusInternalServerError)
return
}
// Clear existing session data and set new authentication state
t.prepareSessionForAuthentication(session, csrfToken, nonce, codeVerifier, req.URL.RequestURI())
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
t.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// handleCallback processes the OIDC callback after user authentication.
// It validates state/CSRF tokens, exchanges authorization code for tokens,
// verifies the received tokens, extracts claims, and establishes the session.
// Parameters:
// - rw: The HTTP response writer.
// - req: The callback request containing authorization code and state.
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error during callback: %v", err)
t.sendErrorResponse(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
t.logger.Errorf("Main session cookie not found in request: %v", err)
} else {
t.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
}
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session during callback")
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
}
// Validate user authorization
if !t.isAllowedUser(userIdentifier) {
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after callback: %v", err)
t.sendErrorResponse(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// handleExpiredToken handles requests with expired or invalid tokens.
// It clears the session data and initiates a new authentication flow.
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request with expired token.
// - session: The session data to clear.
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
// Reset redirect count to prevent loops when handling expired tokens
session.ResetRedirectCount()
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// isUserAuthenticated determines the authentication status and refresh requirements.
// It delegates to provider-specific validation methods that handle different token types
// and expiration behaviors.
// Parameters:
// - session: The session data containing authentication tokens.
//
// Returns:
// - authenticated (bool): True if the user has valid tokens.
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
// - expired (bool): True if the session is unauthenticated, the token is missing,
// or the token verification failed for reasons other than nearing/actual expiration.
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if t.isAzureProvider() {
return t.validateAzureTokens(session)
} else if t.isGoogleProvider() {
return t.validateGoogleTokens(session)
}
// Auth0 and other providers can now use standard validation
// which handles opaque tokens generically
return t.validateStandardTokens(session)
}
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
xhr := req.Header.Get("X-Requested-With")
contentType := req.Header.Get("Content-Type")
accept := req.Header.Get("Accept")
return xhr == "XMLHttpRequest" ||
strings.Contains(contentType, "application/json") ||
strings.Contains(accept, "application/json")
}
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
// This is a heuristic check - actual implementation would depend on
// the specific provider and token metadata
return false // Placeholder implementation
}
File diff suppressed because it is too large Load Diff
+101
View File
@@ -0,0 +1,101 @@
package traefikoidc
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestGeneratePKCEParameters tests the generatePKCEParameters method
func TestGeneratePKCEParameters(t *testing.T) {
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE enabled
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
// Verify the challenge is derived from the verifier
expectedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
})
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE disabled
plugin := &TraefikOidc{
enablePKCE: false,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
})
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
require.NoError(t, err1)
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
require.NoError(t, err2)
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
})
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// The challenge should always be derivable from the verifier
recalculatedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, challenge, recalculatedChallenge,
"challenge should always match the SHA256 hash of verifier")
})
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, _, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// RFC 7636 requires verifier to be 43-128 characters
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
})
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
_, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// SHA256 hash base64 encoded should be 43 characters
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
})
}
+28 -25
View File
@@ -173,7 +173,7 @@ func (bt *BackgroundTask) run() {
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Starting background task: %s", bt.name)
bt.logger.Debug("Starting background task: %s", bt.name)
}
}
@@ -182,7 +182,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (before initial execution)", bt.name)
bt.logger.Debug("Stopping background task: %s (before initial execution)", bt.name)
}
}
return
@@ -201,7 +201,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (during periodic execution)", bt.name)
bt.logger.Debug("Stopping background task: %s (during periodic execution)", bt.name)
}
}
return
@@ -211,7 +211,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (direct stop signal)", bt.name)
bt.logger.Debug("Stopping background task: %s (direct stop signal)", bt.name)
}
}
return
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
// It limits concurrent task execution and tracks failures to prevent system overload
type TaskCircuitBreaker struct {
state int32 // CircuitBreakerState
failureCount int32
lastFailureTime int64 // Unix timestamp
failureThreshold int32
timeout time.Duration
logger *Logger
// Concurrency limiting
concurrentTasks int32 // Current number of running tasks
maxConcurrent int32 // Maximum concurrent tasks allowed
activeTasks map[string]struct{} // Track active task names
tasksMu sync.RWMutex // Separate mutex for task tracking
activeTasks map[string]struct{}
lastFailureTime int64
timeout time.Duration
tasksMu sync.RWMutex
state int32
failureCount int32
failureThreshold int32
concurrentTasks int32
maxConcurrent int32
}
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
max := atomic.LoadInt32(&cb.maxConcurrent)
// For cleanup tasks, be more restrictive (singleton-like behavior)
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
cb.tasksMu.RLock()
hasCleanupTask := false
hasSameTask := false
for activeTask := range cb.activeTasks {
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
hasCleanupTask = true
// Only block if the EXACT same task is already running
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
if activeTask == taskName {
hasSameTask = true
break
}
}
cb.tasksMu.RUnlock()
if hasCleanupTask {
if hasSameTask {
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
}
}
@@ -315,7 +317,7 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
if time.Now().Unix()-lastFailure > int64(cb.timeout.Seconds()) {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerHalfOpen))
if cb.logger != nil {
cb.logger.Info("Circuit breaker transitioning to half-open for task: %s", taskName)
cb.logger.Debug("Circuit breaker transitioning to half-open for task: %s", taskName)
}
return nil
}
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
cb *TaskCircuitBreaker
logger *Logger
mu sync.RWMutex
}
// GlobalTaskRegistry is the singleton instance for managing all background tasks
@@ -467,7 +469,7 @@ func (tr *TaskRegistry) RegisterTask(name string, task *BackgroundTask) error {
tr.cb.OnTaskSuccess(name)
if tr.logger != nil {
tr.logger.Info("Registered background task: %s", name)
tr.logger.Debug("Registered background task: %s", name)
}
return nil
@@ -483,7 +485,7 @@ func (tr *TaskRegistry) UnregisterTask(name string) {
delete(tr.tasks, name)
if tr.logger != nil {
tr.logger.Info("Unregistered background task: %s", name)
tr.logger.Debug("Unregistered background task: %s", name)
}
}
}
@@ -513,7 +515,7 @@ func (tr *TaskRegistry) StopAllTasks() {
for name, task := range tasksCopy {
task.Stop()
if tr.logger != nil {
tr.logger.Info("Stopped background task during shutdown: %s", name)
tr.logger.Debug("Stopped background task during shutdown: %s", name)
}
}
}
@@ -538,7 +540,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
// Start the task if not already running
if !rm.IsTaskRunning(name) {
rm.StartBackgroundTask(name)
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
}
// Get the task from resource manager's internal registry
@@ -641,7 +643,7 @@ func (mm *TaskMemoryMonitor) Start(interval time.Duration) error {
mm.started = true
if mm.logger != nil && !isTestMode() {
mm.logger.Info("Started global task memory monitoring with %v interval", interval)
mm.logger.Debug("Started global task memory monitoring with %v interval", interval)
}
return nil
@@ -787,6 +789,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
}
if mm.logger != nil {
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
}
+224
View File
@@ -0,0 +1,224 @@
package traefikoidc
import (
"errors"
"sync"
"testing"
"time"
)
// globalRegistryMutex protects only the global registry operations
var globalRegistryMutex sync.Mutex
// TestTaskCircuitBreakerOnTaskFailure tests the OnTaskFailure method
func TestTaskCircuitBreakerOnTaskFailure(t *testing.T) {
logger := NewLogger("debug") // Create a real logger
cb := NewTaskCircuitBreaker(3, time.Minute, logger)
// Test failure doesn't trigger open state before threshold
cb.OnTaskFailure("test-task", errors.New("test error"))
if err := cb.CanCreateTask("test-task"); err != nil {
t.Error("Circuit breaker should allow task creation after 1 failure (threshold: 3)")
}
// Test failure count reaches threshold and opens circuit
cb.OnTaskFailure("test-task", errors.New("test error 2"))
cb.OnTaskFailure("test-task", errors.New("test error 3"))
if err := cb.CanCreateTask("test-task"); err == nil {
t.Error("Circuit breaker should prevent task creation after reaching failure threshold")
}
}
// TestResetGlobalTaskRegistry tests the reset functionality
func TestResetGlobalTaskRegistry(t *testing.T) {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
// Get the global registry first
registry := GetGlobalTaskRegistry()
// Create and register a dummy task
logger := NewLogger("debug")
task := NewBackgroundTask("test-task", time.Second, func() {
// Do nothing
}, logger)
registry.RegisterTask("test-task", task)
// Verify task is registered
if registry.GetTaskCount() == 0 {
t.Error("Expected task to be registered")
}
// Reset the registry
ResetGlobalTaskRegistry()
// Get registry again and verify it's empty
newRegistry := GetGlobalTaskRegistry()
if newRegistry.GetTaskCount() != 0 {
t.Error("Expected registry to be empty after reset")
}
}
// TestGetTask tests the GetTask method
func TestGetTask(t *testing.T) {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
// Reset registry to ensure clean state
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
// Test getting non-existent task
task, exists := registry.GetTask("non-existent")
if task != nil || exists {
t.Error("Expected nil and false for non-existent task")
}
// Create and register a task
logger := NewLogger("debug")
newTask := NewBackgroundTask("test-task", time.Second, func() {
// Do nothing
}, logger)
registry.RegisterTask("test-task", newTask)
// Test getting existing task
retrievedTask, exists := registry.GetTask("test-task")
if retrievedTask == nil || !exists {
t.Error("Expected to retrieve registered task")
return
}
if retrievedTask.name != "test-task" {
t.Errorf("Expected task name 'test-task', got '%s'", retrievedTask.name)
}
}
// TestNewTaskMemoryMonitor tests the NewTaskMemoryMonitor function
func TestNewTaskMemoryMonitor(t *testing.T) {
// No mutex needed - this doesn't modify global state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
if monitor == nil {
t.Error("Expected NewTaskMemoryMonitor to return non-nil monitor")
}
}
// TestGetCurrentStats tests the GetCurrentStats method
func TestGetCurrentStats(t *testing.T) {
// Don't hold mutex during background task execution to avoid deadlocks
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
// Start the monitor and let it collect at least one statistic
err := monitor.Start(50 * time.Millisecond)
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Ensure monitor is stopped even if test fails
defer func() {
monitor.Stop()
// Give extra time for cleanup
time.Sleep(50 * time.Millisecond)
}()
// Wait a bit for the monitor to collect stats
time.Sleep(150 * time.Millisecond)
stats, err := monitor.GetCurrentStats()
if err != nil {
// If no stats are available yet, that's acceptable for this test
t.Logf("No memory statistics available yet: %v", err)
return
}
// TaskMemoryStats is a struct, not a pointer, so it can't be nil
if stats.Timestamp.IsZero() {
t.Error("Expected GetCurrentStats to return valid timestamp")
}
}
// TestGetStatsHistory tests the GetStatsHistory method
func TestGetStatsHistory(t *testing.T) {
// No mutex needed - this just creates a monitor and checks its initial state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
history := monitor.GetStatsHistory()
if history == nil {
t.Error("Expected GetStatsHistory to return non-nil history")
}
// A fresh monitor should have empty history
if len(history) != 0 {
t.Logf("History length: %d (may be non-empty due to shared global state)", len(history))
}
}
// TestForceGC tests the ForceGC method
func TestForceGC(t *testing.T) {
// No mutex needed - this doesn't modify global state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
// This should not panic and should work
monitor.ForceGC()
// No specific verification needed, just ensuring it doesn't crash
}
// TestShutdownAllTasks tests the ShutdownAllTasks function
func TestShutdownAllTasks(t *testing.T) {
// Use a unique task name prefix to avoid conflicts with other tests
taskPrefix := "shutdown-test-"
// Create a temporary clean registry state
func() {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
ResetGlobalTaskRegistry()
}()
registry := GetGlobalTaskRegistry()
logger := NewLogger("debug")
// Create some test tasks with unique names
task1 := NewBackgroundTask(taskPrefix+"task1", time.Millisecond, func() {
time.Sleep(100 * time.Millisecond) // Simulate work
}, logger)
task2 := NewBackgroundTask(taskPrefix+"task2", time.Millisecond, func() {
time.Sleep(100 * time.Millisecond) // Simulate work
}, logger)
// Register tasks under mutex protection
func() {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
registry.RegisterTask(taskPrefix+"task1", task1)
registry.RegisterTask(taskPrefix+"task2", task2)
}()
// Start the tasks (outside mutex to avoid deadlock)
task1.Start()
task2.Start()
// Give tasks time to start
time.Sleep(50 * time.Millisecond)
// Shutdown all tasks
ShutdownAllTasks()
// Give shutdown time to complete
time.Sleep(200 * time.Millisecond)
// Note: We can't reliably verify task count due to other tests
// Just ensure shutdown doesn't panic
}
+9 -8
View File
@@ -58,12 +58,13 @@ func TestAzureOIDCRegression(t *testing.T) {
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60 * time.Second,
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
logger: mockLogger,
httpClient: createDefaultHTTPClient(), // Add HTTP client
httpClient: CreateDefaultHTTPClient(), // Add HTTP client
jwkCache: &JWKCache{}, // Add JWK cache
tokenCache: tokenCache,
tokenBlacklist: tokenBlacklist,
@@ -78,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
// Initialize session manager
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
tOidc.sessionManager = sessionManager
// Mock the JWT verification to avoid JWKS lookup issues
@@ -329,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "ValidGoogleTokens",
@@ -475,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
setupSession func() *SessionData
name string
providerType string
setupSession func() *SessionData
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "AzureProvider",
@@ -659,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "UnauthenticatedWithRefreshToken",
+536
View File
@@ -0,0 +1,536 @@
package traefikoidc
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestMemoryMonitorComprehensive tests memory monitor edge cases
func TestMemoryMonitorComprehensive(t *testing.T) {
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Should not panic
assert.NotPanics(t, func() {
monitor.TriggerGC()
})
})
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Initially should return None (no stats yet)
pressure := monitor.GetMemoryPressure()
assert.Equal(t, MemoryPressureNone, pressure)
// Collect stats to populate lastStats
monitor.GetCurrentStats()
// Now should return a valid pressure level
pressure = monitor.GetMemoryPressure()
assert.NotNil(t, pressure)
})
t.Run("StartMonitoring can be called", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Start monitoring should not panic
assert.NotPanics(t, func() {
ctx := context.Background()
monitor.StartMonitoring(ctx, 100*time.Millisecond)
time.Sleep(GetTestDuration(50 * time.Millisecond))
})
// Clean up
monitor.StopMonitoring()
})
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// StopMonitoring should not panic even if not started
assert.NotPanics(t, func() {
monitor.StopMonitoring()
})
// Can be called multiple times safely
assert.NotPanics(t, func() {
monitor.StopMonitoring()
monitor.StopMonitoring()
})
})
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
// Get initial instance
GetGlobalMemoryMonitor()
// Reset
ResetGlobalMemoryMonitor()
// Should be able to get a new instance
monitor := GetGlobalMemoryMonitor()
assert.NotNil(t, monitor)
// Clean up
monitor.StopMonitoring()
ResetGlobalMemoryMonitor()
})
t.Run("String method returns pressure name", func(t *testing.T) {
pressures := []struct {
name string
level MemoryPressureLevel
}{
{level: MemoryPressureNone, name: "None"},
{level: MemoryPressureLow, name: "Low"},
{level: MemoryPressureModerate, name: "Moderate"},
{level: MemoryPressureHigh, name: "High"},
{level: MemoryPressureCritical, name: "Critical"},
{level: MemoryPressureLevel(999), name: "Unknown"},
}
for _, p := range pressures {
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
}
})
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.NotZero(t, stats.Timestamp)
})
}
// TestBackgroundTaskRegistry tests background task registry edge cases
func TestBackgroundTaskRegistry(t *testing.T) {
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
registry1 := GetGlobalTaskRegistry()
registry2 := GetGlobalTaskRegistry()
assert.Equal(t, registry1, registry2, "should return same instance")
})
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-register-task"
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask(taskName, task)
assert.NoError(t, err)
// Verify task was registered
_, exists := registry.GetTask(taskName)
assert.True(t, exists, "task should be registered")
// Clean up
task.Stop()
})
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-singleton-idempotent"
callCount := 0
var mu sync.Mutex
taskFunc := func() {
mu.Lock()
callCount++
mu.Unlock()
}
// First creation should succeed
task1, err1 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err1)
assert.NotNil(t, task1)
// Second creation should also succeed (idempotent)
// Returns same task without error
task2, err2 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
assert.NotNil(t, task2)
// Clean up
if task1 != nil {
task1.Stop()
}
})
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Initially should be 0 or small number
initialCount := registry.GetTaskCount()
// Create a task
task := NewBackgroundTask(
"count-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask("count-test-task", task)
assert.NoError(t, err)
// Count should increase
newCount := registry.GetTaskCount()
assert.Equal(t, initialCount+1, newCount)
// Clean up
task.Stop()
})
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Create multiple tasks
for i := 0; i < 3; i++ {
taskName := "multi-task-" + string(rune(i+'0'))
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask(taskName, task)
}
// Verify tasks were created
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
// Stop all tasks
registry.StopAllTasks()
// Verify all tasks are removed
taskCount := registry.GetTaskCount()
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
})
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
// Create a task
task := NewBackgroundTask(
"reset-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask("reset-test-task", task)
// Reset
ResetGlobalTaskRegistry()
// Get new registry
newRegistry := GetGlobalTaskRegistry()
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
})
}
// TestBackgroundTaskLifecycle tests background task lifecycle
func TestBackgroundTaskLifecycle(t *testing.T) {
t.Run("Start begins task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
executed := false
var mu sync.Mutex
task := NewBackgroundTask(
"lifecycle-test",
50*time.Millisecond,
func() {
mu.Lock()
executed = true
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Wait for execution
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Verify it executed
mu.Lock()
wasExecuted := executed
mu.Unlock()
assert.True(t, wasExecuted, "task should have executed")
})
t.Run("Stop halts task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"stop-test",
30*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Let it run a few times
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Record count
mu.Lock()
countAfterStop := execCount
mu.Unlock()
// Wait more
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Count should not increase
mu.Lock()
finalCount := execCount
mu.Unlock()
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
})
t.Run("Multiple Start calls are safe", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"multi-start-test",
100*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Multiple starts should be safe
task.Start()
task.Start()
task.Start()
// Wait a bit
time.Sleep(GetTestDuration(50 * time.Millisecond))
// Stop task
task.Stop()
// Should have executed, but only one goroutine
mu.Lock()
count := execCount
mu.Unlock()
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
})
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
task := NewBackgroundTask(
"multi-stop-test",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
// Start and stop
task.Start()
time.Sleep(GetTestDuration(20 * time.Millisecond))
// Multiple stops should be safe
assert.NotPanics(t, func() {
task.Stop()
task.Stop()
task.Stop()
})
})
}
// TestMemoryMonitorIntegration tests memory monitor integration
func TestMemoryMonitorIntegration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory monitor integration test in short mode")
}
t.Run("monitoring updates stats", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
defer monitor.StopMonitoring()
// Start monitoring
ctx := context.Background()
monitor.StartMonitoring(ctx, 50*time.Millisecond)
// Wait for at least one check
time.Sleep(GetTestDuration(150 * time.Millisecond))
// Get pressure (should be a valid pressure level)
pressure := monitor.GetMemoryPressure()
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, pressure, "pressure should be a valid level")
// Stop monitoring
monitor.StopMonitoring()
})
t.Run("global memory monitor singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
monitor1 := GetGlobalMemoryMonitor()
monitor2 := GetGlobalMemoryMonitor()
assert.Equal(t, monitor1, monitor2, "should return same instance")
})
}
// TestMemoryStatsCollection tests memory statistics collection
func TestMemoryStatsCollection(t *testing.T) {
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.HeapSysBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.False(t, stats.Timestamp.IsZero())
})
t.Run("Stats include memory pressure", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
stats := monitor.GetCurrentStats()
// Should calculate and include pressure level
assert.NotNil(t, stats.MemoryPressure)
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, stats.MemoryPressure)
})
t.Run("TriggerGC reduces memory", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Allocate some memory
_ = make([]byte, 1024*1024) // 1MB
// Get stats before GC
beforeStats := monitor.GetCurrentStats()
// Trigger GC
monitor.TriggerGC()
// Get stats after GC
afterStats := monitor.GetCurrentStats()
// After GC should have different stats
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
})
}
+241
View File
@@ -0,0 +1,241 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// =============================================================================
// UNIVERSAL CACHE BENCHMARKS
// =============================================================================
func BenchmarkCacheSet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
i++
}
})
}
func BenchmarkCacheGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
for i := 0; i < 1000; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Get(fmt.Sprintf("key%d", i%1000))
i++
}
})
}
func BenchmarkCacheSetGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key%d", i)
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
cache.Get(key)
i++
}
})
}
func BenchmarkCacheLRUEviction(b *testing.B) {
config := createTestCacheConfig()
config.MaxSize = 100
cache := NewUniversalCache(config)
defer cache.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
}
func BenchmarkCacheConcurrent(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
switch i % 3 {
case 0:
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
case 1:
cache.Get(fmt.Sprintf("key%d", i))
case 2:
cache.Delete(fmt.Sprintf("key%d", i))
}
i++
}
})
}
// =============================================================================
// CACHE MANAGER BENCHMARKS
// =============================================================================
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
// =============================================================================
// CACHE COMPATIBILITY BENCHMARKS
// =============================================================================
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
// =============================================================================
// SHARDED CACHE BENCHMARKS
// =============================================================================
func BenchmarkShardedCache(b *testing.B) {
b.Run("Set", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
})
b.Run("Get", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
for i := 0; i < 10000; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get(fmt.Sprintf("key-%d", i%10000))
}
})
b.Run("ParallelSetGet", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, i, 5*time.Minute)
cache.Get(key)
i++
}
})
})
}
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
b.Run("ShardedCache64", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
if !cache.Exists(key) {
cache.Set(key, true, 5*time.Minute)
}
i++
}
})
})
b.Run("GlobalMutexCache", func(b *testing.B) {
var mu sync.RWMutex
data := make(map[string]bool)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
mu.RLock()
_, exists := data[key]
mu.RUnlock()
if !exists {
mu.Lock()
data[key] = true
mu.Unlock()
}
i++
}
})
})
}
+3 -3
View File
@@ -155,9 +155,9 @@ type CacheStrategy interface {
// CacheEntry for backward compatibility
type CacheEntry struct {
Key string
Value interface{}
ExpiresAt time.Time
Value interface{}
Key string
}
// Cache is an alias for backward compatibility
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
// ListNode for backward compatibility
type ListNode struct {
Key string
Value interface{}
Next *ListNode
Prev *ListNode
Key string
}
// NewFixedMetadataCache creates a metadata cache with fixed configuration
+46 -3
View File
@@ -21,10 +21,37 @@ var (
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
cacheManagerInitOnce.Do(func() {
var redisConfig *RedisConfig
var logger *Logger
if config != nil {
logger = NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
redisConfig = config.Redis
}
globalCacheManagerInstance = &CacheManager{
manager: GetUniversalCacheManager(nil),
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
return globalCacheManagerInstance
@@ -61,6 +88,22 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
return &JWKCache{cache: cm.manager.GetJWKCache()}
}
// GetSharedIntrospectionCache returns the shared token introspection cache
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
}
// GetSharedTokenTypeCache returns the shared token type cache
// for caching token type detection results to improve performance
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
@@ -83,7 +126,7 @@ type CacheInterfaceWrapper struct {
// Set stores a value
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
c.cache.Set(key, value, ttl)
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
}
// Get retrieves a value
@@ -110,7 +153,7 @@ func (c *CacheInterfaceWrapper) Cleanup() {
func (c *CacheInterfaceWrapper) Close() {
// Close the underlying cache to stop goroutines
if c.cache != nil {
c.cache.Close()
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
}
}
File diff suppressed because it is too large Load Diff
-319
View File
@@ -1,319 +0,0 @@
// Package circuit_breaker provides circuit breaker implementation for resilience
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// CircuitBreakerState represents the current state of a circuit breaker.
// The circuit breaker pattern prevents cascading failures by monitoring
// error rates and temporarily blocking requests to failing services.
type CircuitBreakerState int
// Circuit breaker states following the standard pattern:
// Closed: Normal operation, requests flow through
// Open: Circuit is tripped, requests are blocked
// HalfOpen: Testing state, limited requests allowed to test recovery
const (
// CircuitBreakerClosed allows all requests through (normal operation)
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen blocks all requests (service is failing)
CircuitBreakerOpen
// CircuitBreakerHalfOpen allows limited requests to test service recovery
CircuitBreakerHalfOpen
)
// String returns a string representation of the circuit breaker state
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Logger interface for dependency injection
type Logger interface {
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// BaseRecoveryMechanism interface for common functionality
type BaseRecoveryMechanism interface {
RecordRequest()
RecordSuccess()
RecordFailure()
GetBaseMetrics() map[string]interface{}
LogInfo(format string, args ...interface{})
LogError(format string, args ...interface{})
LogDebug(format string, args ...interface{})
}
// CircuitBreaker implements the circuit breaker pattern for external service calls.
// It monitors failure rates and automatically opens the circuit when failures
// exceed the threshold, preventing further requests until the service recovers.
type CircuitBreaker struct {
// baseRecovery provides common functionality
baseRecovery BaseRecoveryMechanism
// maxFailures is the threshold for opening the circuit
maxFailures int
// timeout is how long to wait before allowing requests in half-open state
timeout time.Duration
// resetTimeout is how long to wait before transitioning from open to half-open
resetTimeout time.Duration
// state tracks the current circuit breaker state
state CircuitBreakerState
// failures counts consecutive failures
failures int64
// lastFailureTime records when the last failure occurred
lastFailureTime time.Time
// mutex protects shared state
mutex sync.RWMutex
// logger for debugging and monitoring
logger Logger
}
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
// These settings control when the circuit opens and how it recovers.
type CircuitBreakerConfig struct {
// MaxFailures is the number of failures before opening the circuit
MaxFailures int `json:"max_failures"`
// Timeout is how long to wait before trying to recover (open -> half-open)
Timeout time.Duration `json:"timeout"`
// ResetTimeout is how long to wait before fully closing the circuit
ResetTimeout time.Duration `json:"reset_timeout"`
}
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
// Configured for typical web service scenarios with moderate tolerance for failures.
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 60 * time.Second,
ResetTimeout: 30 * time.Second,
}
}
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
// The circuit breaker starts in the closed state, allowing all requests through.
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
return &CircuitBreaker{
baseRecovery: baseRecovery,
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
logger: logger,
}
}
// ExecuteWithContext executes a function through the circuit breaker with context.
// It checks if requests are allowed, executes the function, and updates the circuit state
// based on the result. Implements the ErrorRecoveryMechanism interface.
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
if cb.baseRecovery != nil {
cb.baseRecovery.RecordRequest()
}
if !cb.allowRequest() {
return fmt.Errorf("circuit breaker is open")
}
err := fn()
if err != nil {
cb.recordFailure()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordFailure()
}
return err
}
cb.recordSuccess()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordSuccess()
}
return nil
}
// Execute executes a function through the circuit breaker without context.
// This is provided for backward compatibility with existing code.
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest determines whether to allow a request based on the circuit state.
// Handles state transitions from open to half-open based on timeout.
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
defer cb.mutex.Unlock()
now := time.Now()
switch cb.state {
case CircuitBreakerClosed:
return true
case CircuitBreakerOpen:
if now.Sub(cb.lastFailureTime) > cb.timeout {
cb.state = CircuitBreakerHalfOpen
if cb.logger != nil {
cb.logger.Infof("Circuit breaker transitioning to half-open state")
}
return true
}
return false
case CircuitBreakerHalfOpen:
return true
default:
return false
}
}
// recordFailure records a failure and potentially opens the circuit.
// Updates failure count and triggers state transitions when thresholds are exceeded.
func (cb *CircuitBreaker) recordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.failures++
cb.lastFailureTime = time.Now()
switch cb.state {
case CircuitBreakerClosed:
if cb.failures >= int64(cb.maxFailures) {
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
}
}
case CircuitBreakerHalfOpen:
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
}
}
}
// recordSuccess records a successful request and potentially closes the circuit.
// Resets failure count and transitions from half-open to closed state on success.
func (cb *CircuitBreaker) recordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
switch cb.state {
case CircuitBreakerHalfOpen:
cb.failures = 0
cb.state = CircuitBreakerClosed
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
}
case CircuitBreakerClosed:
cb.failures = 0
}
}
// GetState returns the current state of the circuit breaker.
// Thread-safe method for monitoring circuit breaker status.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// Reset resets the circuit breaker to its initial closed state.
// Clears failure count and state, effectively recovering from any open state.
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = CircuitBreakerClosed
atomic.StoreInt64(&cb.failures, 0)
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
}
}
// IsAvailable returns whether the circuit breaker is currently allowing requests.
// This provides a quick way to check if the service is available.
func (cb *CircuitBreaker) IsAvailable() bool {
return cb.allowRequest()
}
// GetMetrics returns comprehensive metrics about the circuit breaker.
// Includes state information, failure counts, configuration, and base metrics.
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
cb.mutex.RLock()
state := cb.state
failures := cb.failures
lastFailureTime := cb.lastFailureTime
cb.mutex.RUnlock()
var metrics map[string]interface{}
if cb.baseRecovery != nil {
metrics = cb.baseRecovery.GetBaseMetrics()
} else {
metrics = make(map[string]interface{})
}
metrics["state"] = state.String()
metrics["current_failures"] = failures
metrics["max_failures"] = cb.maxFailures
metrics["timeout"] = cb.timeout.String()
metrics["reset_timeout"] = cb.resetTimeout.String()
if !lastFailureTime.IsZero() {
metrics["last_failure_time"] = lastFailureTime
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
}
return metrics
}
// GetFailureCount returns the current failure count
func (cb *CircuitBreaker) GetFailureCount() int64 {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.failures
}
// GetLastFailureTime returns the time of the last failure
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.lastFailureTime
}
// IsOpen returns true if the circuit breaker is in open state
func (cb *CircuitBreaker) IsOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerOpen
}
// IsClosed returns true if the circuit breaker is in closed state
func (cb *CircuitBreaker) IsClosed() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerClosed
}
// IsHalfOpen returns true if the circuit breaker is in half-open state
func (cb *CircuitBreaker) IsHalfOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerHalfOpen
}
-981
View File
@@ -1,981 +0,0 @@
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// Mock implementations for testing
type mockLogger struct {
infoLogs []string
errorLogs []string
debugLogs []string
mu sync.RWMutex
}
func (m *mockLogger) Infof(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Errorf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Debugf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
//lint:ignore U1000 May be needed for future error log verification tests
func (m *mockLogger) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
//lint:ignore U1000 May be needed for future test isolation
func (m *mockLogger) reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = nil
m.errorLogs = nil
m.debugLogs = nil
}
type mockBaseRecoveryMechanism struct {
requestCount int64
successCount int64
failureCount int64
infoLogs []string
errorLogs []string
debugLogs []string
baseMetrics map[string]interface{}
mu sync.RWMutex
}
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
return &mockBaseRecoveryMechanism{
baseMetrics: make(map[string]interface{}),
}
}
func (m *mockBaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&m.requestCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&m.successCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&m.failureCount, 1)
}
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]interface{})
for k, v := range m.baseMetrics {
result[k] = v
}
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
result["total_successes"] = atomic.LoadInt64(&m.successCount)
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
return result
}
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
return atomic.LoadInt64(&m.requestCount)
}
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
return atomic.LoadInt64(&m.successCount)
}
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
return atomic.LoadInt64(&m.failureCount)
}
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
func TestCircuitBreakerState_String(t *testing.T) {
tests := []struct {
state CircuitBreakerState
expected string
}{
{CircuitBreakerClosed, "closed"},
{CircuitBreakerOpen, "open"},
{CircuitBreakerHalfOpen, "half-open"},
{CircuitBreakerState(999), "unknown"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := tt.state.String()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
}
}
func TestNewCircuitBreaker(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
if cb == nil {
t.Fatal("NewCircuitBreaker returned nil")
}
if cb.maxFailures != 3 {
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
}
if cb.timeout != 30*time.Second {
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
}
if cb.resetTimeout != 15*time.Second {
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
}
if cb.state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
}
if cb.logger != logger {
t.Error("Expected logger to be set")
}
if cb.baseRecovery != baseRecovery {
t.Error("Expected baseRecovery to be set")
}
}
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getSuccessCount() != 1 {
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
}
}
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getFailureCount() != 1 {
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
}
}
func TestCircuitBreaker_Execute(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.Execute(testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
}
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
// First failure
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on first failure, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
}
// Second failure - should open circuit
err = cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on second failure, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
}
// Third attempt - should be blocked
callCount := 0
blockedFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(ctx, blockedFunc)
if err == nil {
t.Error("Expected error when circuit is open")
}
if callCount != 0 {
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
}
}
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond, // Very short for testing
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next request should transition to half-open
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error in half-open state, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
}
}
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout to allow half-open transition
time.Sleep(15 * time.Millisecond)
// First call should transition to half-open, but we'll force it by checking allowRequest
if !cb.allowRequest() {
t.Error("Expected allowRequest to return true after timeout")
}
if cb.GetState() != CircuitBreakerHalfOpen {
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
}
// Failure in half-open should return to open
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
}
}
func TestCircuitBreaker_Reset(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Reset circuit
cb.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
}
if cb.GetFailureCount() != 0 {
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
}
// Should allow requests again
callCount := 0
err := cb.ExecuteWithContext(context.Background(), func() error {
callCount++
return nil
})
if err != nil {
t.Errorf("Expected no error after reset, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
}
}
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially available
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should not be available when open
if cb.IsAvailable() {
t.Error("Expected circuit breaker to be unavailable when open")
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Should be available again after timeout (half-open)
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available after timeout")
}
}
func TestCircuitBreaker_StateCheckers(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially closed
if !cb.IsClosed() {
t.Error("Expected circuit breaker to be closed initially")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open initially")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should be open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when open")
}
if !cb.IsOpen() {
t.Error("Expected circuit breaker to be open")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open when open")
}
// Wait for timeout and trigger half-open
time.Sleep(15 * time.Millisecond)
cb.allowRequest() // This will transition to half-open
// Should be half-open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when half-open")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open when half-open")
}
if !cb.IsHalfOpen() {
t.Error("Expected circuit breaker to be half-open")
}
}
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Record some activity
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
metrics := cb.GetMetrics()
// Check circuit breaker specific metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["current_failures"] != int64(1) {
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
if metrics["timeout"] != "30s" {
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
}
if metrics["reset_timeout"] != "15s" {
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
}
// Check base metrics are included
if metrics["total_requests"] != int64(1) {
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
}
if metrics["custom_metric"] != "custom_value" {
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
}
// Check failure time metrics
if _, exists := metrics["last_failure_time"]; !exists {
t.Error("Expected last_failure_time to exist")
}
if _, exists := metrics["time_since_last_failure"]; !exists {
t.Error("Expected time_since_last_failure to exist")
}
}
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
metrics := cb.GetMetrics()
// Should still have circuit breaker metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
// Should not have base metrics
if _, exists := metrics["total_requests"]; exists {
t.Error("Expected total_requests not to exist without base recovery")
}
}
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially should be zero
if !cb.GetLastFailureTime().IsZero() {
t.Error("Expected last failure time to be zero initially")
}
// Record a failure
before := time.Now()
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
after := time.Now()
lastFailure := cb.GetLastFailureTime()
if lastFailure.IsZero() {
t.Error("Expected last failure time to be set after failure")
}
if lastFailure.Before(before) || lastFailure.After(after) {
t.Errorf("Expected last failure time to be between %v and %v, got %v",
before, after, lastFailure)
}
}
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
// Should work fine without base recovery
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
}
}
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 10, // Higher threshold for concurrent test
Timeout: 100 * time.Millisecond,
ResetTimeout: 50 * time.Millisecond,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
const numGoroutines = 10
const numOperations = 50
var wg sync.WaitGroup
successCount := int64(0)
errorCount := int64(0)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
err := cb.ExecuteWithContext(context.Background(), func() error {
// Simulate some failures
if j%10 == 9 { // Every 10th operation fails
return fmt.Errorf("simulated error")
}
return nil
})
if err != nil {
atomic.AddInt64(&errorCount, 1)
} else {
atomic.AddInt64(&successCount, 1)
}
// Intermittently check state and metrics
if j%5 == 0 {
cb.GetState()
cb.GetMetrics()
cb.IsAvailable()
}
}
}(i)
}
wg.Wait()
// Verify we got both successes and errors
finalSuccessCount := atomic.LoadInt64(&successCount)
finalErrorCount := atomic.LoadInt64(&errorCount)
if finalSuccessCount == 0 {
t.Error("Expected some successful operations")
}
if finalErrorCount == 0 {
t.Error("Expected some failed operations")
}
totalOperations := finalSuccessCount + finalErrorCount
expectedMax := int64(numGoroutines * numOperations)
if totalOperations > expectedMax {
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
}
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
finalSuccessCount, finalErrorCount, cb.GetState())
}
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Check that error was logged when circuit opened
errorLogs := baseRecovery.getErrorLogs()
if len(errorLogs) == 0 {
t.Error("Expected error log when circuit breaker opened")
} else {
if !contains(errorLogs, "Circuit breaker opened after") {
t.Errorf("Expected circuit opening log, got %v", errorLogs)
}
}
// Wait and trigger half-open
time.Sleep(15 * time.Millisecond)
// Successful request should close circuit and log
cb.ExecuteWithContext(context.Background(), func() error {
return nil
})
// Check that success was logged when circuit closed
infoLogs := baseRecovery.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log when circuit breaker closed")
} else {
if !contains(infoLogs, "Circuit breaker closed after successful request") {
t.Errorf("Expected circuit closing log, got %v", infoLogs)
}
}
// Reset should also be logged
cb.Reset()
infoLogs = baseRecovery.getInfoLogs()
if !contains(infoLogs, "Circuit breaker has been reset") {
t.Errorf("Expected reset log, got %v", infoLogs)
}
}
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Wait for timeout and check half-open transition logging
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next allowRequest call should log transition to half-open
cb.allowRequest()
infoLogs := logger.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log for half-open transition")
} else {
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
t.Errorf("Expected half-open transition log, got %v", infoLogs)
}
}
}
// Helper function to check if a slice contains a string with substring
func contains(slice []string, substr string) bool {
for _, s := range slice {
if len(s) >= len(substr) && s[:len(substr)] == substr {
return true
}
}
return false
}
// Benchmark tests
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testFunc := func() error {
return nil
}
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.ExecuteWithContext(ctx, testFunc)
}
})
}
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
config := CircuitBreakerConfig{
MaxFailures: 1000, // High threshold to avoid opening during benchmark
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.ExecuteWithContext(ctx, testFunc)
}
}
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.GetState()
}
})
}
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Add some activity
for i := 0; i < 100; i++ {
if i%2 == 0 {
cb.ExecuteWithContext(context.Background(), func() error { return nil })
} else {
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.GetMetrics()
}
}
File diff suppressed because it is too large Load Diff
-211
View File
@@ -1,211 +0,0 @@
// Package config provides configuration management for the OIDC middleware
package config
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
const (
minEncryptionKeyLength = 16
ConstSessionTimeout = 86400
)
//lint:ignore U1000 May be referenced for default exclusion patterns
var defaultExcludedURLs = map[string]struct{}{
"/favicon.ico": {},
"/robots.txt": {},
"/health": {},
"/.well-known/": {},
"/metrics": {},
"/ping": {},
"/api/": {},
"/static/": {},
"/assets/": {},
"/js/": {},
"/css/": {},
"/images/": {},
"/fonts/": {},
}
// Settings manages configuration and initialization for the OIDC middleware
type Settings struct {
logger Logger
}
// Logger interface for dependency injection
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// Config represents the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerUrl"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackUrl"`
LogoutURL string `json:"logoutUrl"`
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHttps"`
LogLevel string `json:"logLevel"`
Scopes []string `json:"scopes"`
OverrideScopes bool `json:"overrideScopes"`
AllowedUsers []string `json:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedUrls"`
EnablePKCE bool `json:"enablePkce"`
RateLimit int `json:"rateLimit"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
Headers []HeaderConfig `json:"headers"`
HTTPClient *http.Client `json:"-"`
CookieDomain string `json:"cookieDomain"`
}
// HeaderConfig represents header template configuration
type HeaderConfig struct {
Name string `json:"name"`
Value string `json:"value"`
}
// NewSettings creates a new Settings instance
func NewSettings(logger Logger) *Settings {
return &Settings{
logger: logger,
}
}
// CreateConfig creates a default configuration
func CreateConfig() *Config {
return &Config{
LogLevel: "INFO",
ForceHTTPS: true,
EnablePKCE: true,
RateLimit: 10,
RefreshGracePeriodSeconds: 60,
Scopes: []string{"openid", "profile", "email"},
Headers: []HeaderConfig{},
}
}
// InitializeTraefikOidc would initialize and configure a new TraefikOidc instance
// This functionality has been moved to the main New function in main.go
// This function is kept for compatibility but should not be used
func (s *Settings) InitializeTraefikOidc(ctx context.Context, next http.Handler, config *Config, name string) (interface{}, error) {
return nil, fmt.Errorf("InitializeTraefikOidc is deprecated - use New function from main package instead")
}
//lint:ignore U1000 Kept for backward compatibility
func (s *Settings) setupHeaderTemplates(t interface{}, config *Config, logger Logger) error {
logger.Debug("setupHeaderTemplates is deprecated")
return nil
}
//lint:ignore U1000 May be needed for future background service management
func (s *Settings) startBackgroundServices(ctx context.Context, logger Logger) {
startReplayCacheCleanup(ctx, logger)
// Start memory monitoring for leak detection and performance insights
memoryMonitor := GetGlobalMemoryMonitor()
memoryMonitor.StartMonitoring(ctx, 60*time.Second) // Monitor every minute
logger.Debug("Started global memory monitoring")
}
// Utility functions
//lint:ignore U1000 May be needed for future scope processing
func deduplicateScopes(scopes []string) []string {
seen := make(map[string]bool)
result := []string{}
for _, scope := range scopes {
if !seen[scope] {
seen[scope] = true
result = append(result, scope)
}
}
return result
}
//lint:ignore U1000 May be needed for future scope merging operations
func mergeScopes(defaultScopes, userScopes []string) []string {
result := make([]string, len(defaultScopes))
copy(result, defaultScopes)
return append(result, userScopes...)
}
//lint:ignore U1000 May be needed for future utility operations
func createStringMap(items []string) map[string]struct{} {
result := make(map[string]struct{})
for _, item := range items {
result[item] = struct{}{}
}
return result
}
//lint:ignore U1000 May be needed for future case-insensitive operations
func createCaseInsensitiveStringMap(items []string) map[string]struct{} {
result := make(map[string]struct{})
for _, item := range items {
result[strings.ToLower(item)] = struct{}{}
}
return result
}
//lint:ignore U1000 May be needed for future test environment detection
func isTestMode() bool {
// This function should be implemented based on environment detection logic
return false
}
// External dependencies that need to be provided
// TraefikOidc struct is defined in types.go
// These functions need to be provided by external packages
func NewLogger(level string) Logger { return nil }
func CreateDefaultHTTPClient() *http.Client { return nil }
func CreateTokenHTTPClient() *http.Client { return nil }
func GetGlobalCacheManager(*sync.WaitGroup) CacheManager { return nil }
func NewSessionManager(string, bool, string, Logger) (SessionManager, error) { return nil, nil }
func NewErrorRecoveryManager(Logger) ErrorRecoveryManager { return nil }
//lint:ignore U1000 May be needed for future token claim extraction
func extractClaims(string) (map[string]interface{}, error) { return nil, nil }
//lint:ignore U1000 May be needed for future replay attack prevention
func startReplayCacheCleanup(context.Context, Logger) {}
func GetGlobalMemoryMonitor() MemoryMonitor { return nil }
// Interfaces for external dependencies
type CacheManager interface {
GetSharedTokenBlacklist() CacheInterface
GetSharedTokenCache() *TokenCache
GetSharedMetadataCache() *MetadataCache
GetSharedJWKCache() JWKCacheInterface
Close() error
}
type SessionManager interface{}
type ErrorRecoveryManager interface{}
type MemoryMonitor interface {
StartMonitoring(ctx context.Context, interval time.Duration)
}
type CacheInterface interface {
Set(key string, value interface{}, ttl time.Duration)
Get(key string) (interface{}, bool)
Delete(key string)
SetMaxSize(size int)
Cleanup()
Close()
}
type TokenCache struct{}
type MetadataCache struct{}
type JWKCacheInterface interface{}
+116
View File
@@ -0,0 +1,116 @@
package traefikoidc
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return result, nil
}
// MarshalJSON for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalJSON() ([]byte, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return json.Marshal(result)
}
// MarshalYAML for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalYAML() (interface{}, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return result, nil
}
File diff suppressed because it is too large Load Diff
+8 -8
View File
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that CSRF tokens persist through the authentication flow
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
// Create a session manager
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Create initial request
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that marking session as dirty forces save
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test Azure-specific session handling
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate Azure callback scenario
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test session continuity through auth flow
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Step 1: Initial request
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test large token handling doesn't affect CSRF
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// We can't fully initialize TraefikOidc without network access,
// but we can test the session management directly
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
require.NoError(t, err)
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// TestRegressionLoginLoop specifically tests the fix for issue #53
func TestRegressionLoginLoop(t *testing.T) {
// This test verifies that the specific changes made to fix the login loop work correctly
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate the exact flow that was causing the login loop
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
func TestCSRFValidationTiming(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
+364
View File
@@ -0,0 +1,364 @@
//go:build !yaegi
package traefikoidc
import (
"testing"
)
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Explicitly set defaults to test backward compatibility
ts.tOidc.roleClaimName = "roles"
ts.tOidc.groupClaimName = "groups"
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
claims := map[string]interface{}{
"groups": []interface{}{"admin", "users"},
"roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names for Auth0
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with Auth0-style namespaced claims
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{"admin", "users"},
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom simple claim names
ts.tOidc.roleClaimName = "user_roles"
ts.tOidc.groupClaimName = "user_groups"
// Create token with custom claim names
claims := map[string]interface{}{
"user_groups": []interface{}{"engineering", "product"},
"user_roles": []interface{}{"developer", "manager"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
t.Errorf("Expected groups [engineering product], got %v", groups)
}
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
t.Errorf("Expected roles [developer manager], got %v", roles)
}
}
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
func TestCustomClaimNames_MissingClaims(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
ts.tOidc.groupClaimName = "custom_groups"
// Create token WITHOUT the custom claims
claims := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should return empty slices, not error
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with malformed role claim (not an array)
claims := map[string]interface{}{
"custom_roles": "this-should-be-an-array",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed role claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_roles claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with malformed group claim (not an array)
claims := map[string]interface{}{
"custom_groups": 12345, // Not an array
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed group claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_groups claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only role claim name (group uses default)
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "groups" // default
// Create token with mixed claim names
claims := map[string]interface{}{
"groups": []interface{}{"admin"},
"https://myapp.com/roles": []interface{}{"editor"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin"}) {
t.Errorf("Expected groups [admin], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor"}) {
t.Errorf("Expected roles [editor], got %v", roles)
}
}
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only group claim name (role uses default)
ts.tOidc.roleClaimName = "roles" // default
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with mixed claim names
claims := map[string]interface{}{
"roles": []interface{}{"viewer"},
"https://myapp.com/groups": []interface{}{"users"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"users"}) {
t.Errorf("Expected groups [users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"viewer"}) {
t.Errorf("Expected roles [viewer], got %v", roles)
}
}
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with empty arrays
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{},
"https://myapp.com/roles": []interface{}{},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_roles": []interface{}{"role1", 12345, "role2", true},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
t.Errorf("Expected roles [role1 role2], got %v", roles)
}
}
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
t.Errorf("Expected groups [group1 group2], got %v", groups)
}
}
+424
View File
@@ -0,0 +1,424 @@
# Auth0 Audience Validation Guide
## Overview
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
## Table of Contents
1. [Understanding Audiences](#understanding-audiences)
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
3. [Configuration Options](#configuration-options)
4. [Security Recommendations](#security-recommendations)
5. [Troubleshooting](#troubleshooting)
---
## Understanding Audiences
### What is an Audience?
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
### Why Does This Matter?
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
---
## The Three Auth0 Scenarios
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
**Configuration:**
```yaml
audience: "https://my-api.example.com" # Your API identifier from Auth0
```
**What Happens:**
1. Authorization request includes `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
3. Middleware validates:
- ID tokens against `client_id`
- Access tokens against custom audience
**Result:** ✅ Fully secure, OIDC compliant
---
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
**Configuration:**
```yaml
# audience not specified (defaults to client_id)
```
**What Happens:**
1. Authorization request WITHOUT `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
3. Access token validation fails (audience mismatch)
4. Middleware falls back to ID token validation
**Security Warning:**
```
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
⚠️ This could allow tokens intended for different APIs to grant access
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
```
**Recommended Fix:**
```yaml
strictAudienceValidation: true # Reject sessions with audience mismatch
```
**Result:**
- Default: ⚠️ Works but logs security warnings
- With strict mode: ✅ Secure (rejects mismatched tokens)
---
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
**Configuration:**
```yaml
allowOpaqueTokens: true # Enable opaque token support
requireTokenIntrospection: true # Require introspection (recommended)
```
**What Happens:**
1. Auth0 issues opaque (non-JWT) access token
2. Middleware detects opaque token (not 3 parts separated by dots)
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
**Requirements:**
- Provider must support `introspection_endpoint` in OIDC discovery
- Client must have introspection permissions
**Result:** ✅ Secure with introspection, ⚠️ risky without
---
## Configuration Options
### Audience Settings
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `audience` | string | `client_id` | Expected audience for access tokens |
**Example:**
```yaml
# .traefik.yml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
audience: "https://my-api.example.com"
```
---
### Security Mode Settings
#### `strictAudienceValidation`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` for production
**What it does:**
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
**Example:**
```yaml
strictAudienceValidation: true
```
**When to use:**
- ✅ Always use in production environments
- ✅ When you have custom API audiences configured in Auth0
- ⚠️ May break existing deployments relying on Scenario 2 behavior
---
#### `allowOpaqueTokens`
**Type:** boolean
**Default:** `false`
**What it does:**
- When `true`: Accepts opaque (non-JWT) access tokens
- When `false`: Only accepts JWT access tokens
**Example:**
```yaml
allowOpaqueTokens: true
```
**When to use:**
- ✅ When Auth0 issues opaque tokens (no default API configured)
- ✅ When using Auth0 Management API tokens
- ⚠️ Requires introspection endpoint for security
---
#### `requireTokenIntrospection`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` when `allowOpaqueTokens=true`
**What it does:**
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
- When `false`: Falls back to ID token validation for opaque tokens
**Example:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
**When to use:**
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
- ⚠️ Requires provider to expose introspection endpoint
---
## Security Recommendations
### Recommended Configuration for Auth0
**For APIs with custom audiences (Scenario 1):**
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
allowOpaqueTokens: false
```
**For default Auth0 setup (Scenario 2):**
```yaml
# Don't set audience (defaults to client_id)
strictAudienceValidation: true # Enforce proper configuration
```
**For opaque tokens (Scenario 3):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Security Best Practices
1.**Always set `strictAudienceValidation: true` in production**
2.**Configure custom API audiences in Auth0 dashboard**
3.**Use `requireTokenIntrospection: true` if accepting opaque tokens**
4.**Monitor logs for security warnings**
5.**Don't rely on Scenario 2 fallback behavior**
---
## Troubleshooting
### "Access token validation failed due to audience mismatch"
**Symptom:**
```
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
```
**Cause:** Access token audience doesn't match configured audience
**Solutions:**
1. **Configure correct audience:**
```yaml
audience: "https://your-api-identifier" # From Auth0 API settings
```
2. **Update Auth0 authorization request:**
- Ensure `audience` parameter is included in authorize URL
- Middleware automatically adds this when `audience != client_id`
3. **Accept the behavior (not recommended):**
```yaml
strictAudienceValidation: false # Logs warnings but allows
```
---
### "Opaque token detected but allowOpaqueTokens=false"
**Symptom:**
```
⚠️ Opaque access token detected but allowOpaqueTokens=false
```
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
**Solutions:**
1. **Enable opaque tokens:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
2. **Configure Auth0 to issue JWT access tokens:**
- Create an API in Auth0 dashboard
- Set API identifier as `audience` in configuration
---
### "Introspection endpoint not available"
**Symptom:**
```
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
```
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
**Solutions:**
1. **Check provider discovery:**
```bash
curl https://YOUR_DOMAIN/.well-known/openid-configuration
```
Look for `introspection_endpoint`
2. **Disable required introspection (less secure):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: false # Falls back to ID token
```
3. **Use JWT access tokens instead** (recommended)
---
### "Token introspection required but endpoint not available"
**Symptom:**
```
❌ SECURITY: Opaque token rejected (introspection required but failed)
```
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
**Solutions:**
1. **Disable required introspection:**
```yaml
requireTokenIntrospection: false
```
2. **Configure Auth0 to issue JWT tokens** (better solution)
---
## Advanced Topics
### Token Type Detection
The middleware uses a sophisticated 6-step detection algorithm:
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
2. **Explicit type claims**: `token_use`, `token_type`
3. **`scope` claim**: Present → Access Token
4. **`nonce` claim**: Present → ID Token (OIDC spec)
5. **Audience check**: `aud == client_id` only → ID Token
6. **Default**: Access Token
### OAuth 2.0 Token Introspection (RFC 7662)
When opaque tokens are detected:
1. Middleware calls provider's `introspection_endpoint`
2. Authenticates using client credentials
3. Receives response with `active` status and claims
4. Caches result for 5 minutes (configurable via TTL)
5. Validates expiration, not-before, and audience if present
**Cache behavior:**
- Cache key: Token hash
- TTL: 5 minutes or token expiry (whichever is shorter)
- Reduces introspection requests for frequently used tokens
---
## Reference Links
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
---
## Migration Guide
### From Previous Versions
**If you're upgrading from a version without these features:**
1. **No action required for default behavior** - backward compatible
2. **Recommended: Enable strict mode gradually**
```yaml
# Step 1: Enable and monitor logs
strictAudienceValidation: false # Default
# Step 2: After confirming no warnings, enable
strictAudienceValidation: true
```
3. **For opaque tokens: Enable explicitly**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
### Testing Your Configuration
1. **Check logs for warnings:**
```bash
# Look for Scenario 2 warnings
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
# Look for opaque token warnings
grep "Opaque" /var/log/traefik.log
```
2. **Test with curl:**
```bash
# Get token from Auth0
ACCESS_TOKEN="your_access_token"
# Test request
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
https://your-app.example.com/api
```
3. **Monitor for security warnings in production logs**
---
## Support
For issues or questions:
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
- Security issues: See SECURITY.md for responsible disclosure
---
**Last Updated:** 2025-01-09
**Version:** 0.7.8+
+1
View File
@@ -0,0 +1 @@
traefikoidc.raczylo.com
+456
View File
@@ -0,0 +1,456 @@
# Configuration Reference
Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
- [Access Control](#access-control)
- [Headers Configuration](#headers-configuration)
- [Security Headers](#security-headers)
- [Scope Configuration](#scope-configuration)
- [Advanced Options](#advanced-options)
---
## Required Parameters
| Parameter | Type | Description | Example |
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
### Basic Configuration Example
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
```
---
## Optional Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
### TLS Termination at Load Balancer
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
```yaml
forceHTTPS: true # Required for correct redirect URIs
```
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
---
## Security Options
### Audience Validation
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audience` | string | `clientID` | Expected audience for access token validation |
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
#### Production Security Configuration
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
```
#### Opaque Token Support
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Other Security Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
---
## Session Management
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
### Multi-Subdomain Setup
```yaml
cookieDomain: .example.com # Share cookies across subdomains
```
### Multiple Middleware Instances
When running multiple middleware instances with different authorization requirements, use unique prefixes:
```yaml
# User authentication middleware
---
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-userauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_userauth_"
sessionEncryptionKey: user-encryption-key-min-32-bytes
# ... other config
---
# Admin authentication middleware
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-adminauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_adminauth_"
sessionEncryptionKey: admin-encryption-key-min-32-bytes
allowedUsers:
- admin@example.com
# ... other config
```
### Extended Session Duration
```yaml
sessionMaxAge: 604800 # 7 days
# Common values:
# 3600 - 1 hour (high security)
# 86400 - 1 day (default)
# 259200 - 3 days
# 604800 - 7 days
# 2592000 - 30 days
```
---
## Access Control
### User Restrictions
| Parameter | Type | Description |
|-----------|------|-------------|
| `allowedUserDomains` | []string | Restrict to specific email domains |
| `allowedUsers` | []string | Specific email addresses allowed |
| `allowedRolesAndGroups` | []string | Required roles or groups |
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
### Domain Restriction
```yaml
allowedUserDomains:
- company.com
- subsidiary.com
```
### Specific User Access
```yaml
allowedUsers:
- user@example.com
- contractor@external.org
```
### Role-Based Access Control
```yaml
allowedRolesAndGroups:
- admin
- developer
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
```
### Access Control Logic
- If only `allowedUsers` is set: Only specified emails can access
- If only `allowedUserDomains` is set: Only specified domains can access
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
- If neither is set: Any authenticated user can access
### Users Without Email (Azure AD)
For Azure AD service accounts or users without email:
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
```
---
## Headers Configuration
### Default Headers
The middleware sets these headers for downstream services:
| Header | Description |
|--------|-------------|
| `X-Forwarded-User` | User's email address |
| `X-User-Groups` | Comma-separated user groups |
| `X-User-Roles` | Comma-separated user roles |
| `X-Auth-Request-Redirect` | Original request URI |
| `X-Auth-Request-User` | User's email address |
| `X-Auth-Request-Token` | User's ID token |
### Minimal Headers Mode
For "431 Request Header Fields Too Large" errors:
```yaml
minimalHeaders: true # Only forwards X-Forwarded-User
```
### Custom Templated Headers
```yaml
headers:
- name: "X-User-Email"
value: "{{{{.Claims.email}}}}"
- name: "X-User-ID"
value: "{{{{.Claims.sub}}}}"
- name: "Authorization"
value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles"
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
```
**Template Variables:**
- `{{.Claims.field}}` - ID token claims
- `{{.AccessToken}}` - Raw access token
- `{{.IdToken}}` - Raw ID token
- `{{.RefreshToken}}` - Raw refresh token
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
---
## Security Headers
### Security Profiles
| Profile | Use Case | Security Level |
|---------|----------|----------------|
| `default` | Standard web apps | High |
| `strict` | Maximum security | Very High |
| `development` | Local development | Medium |
| `api` | API endpoints | High |
| `custom` | Custom requirements | Configurable |
### Basic Configuration
```yaml
securityHeaders:
enabled: true
profile: "default"
```
### API with CORS
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.example.com"
corsAllowCredentials: true
```
### Custom Security Configuration
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
# HSTS
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and Content Protection
frameOptions: "DENY"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# CORS
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type"]
corsAllowCredentials: true
corsMaxAge: 86400
# Custom Headers
customHeaders:
X-Custom-Header: "value"
# Server Identification
disableServerHeader: true
disablePoweredByHeader: true
```
### CORS Origin Patterns
```yaml
corsAllowedOrigins:
- "https://example.com" # Exact match
- "https://*.example.com" # Subdomain wildcard
- "http://localhost:*" # Port wildcard (development)
```
---
## Scope Configuration
### Default Behavior (Append Mode)
```yaml
scopes:
- roles
- custom_scope
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
```
### Override Mode
```yaml
overrideScopes: true
scopes:
- openid
- profile
- custom_scope
# Result: ["openid", "profile", "custom_scope"]
```
---
## Advanced Options
### Dynamic Client Registration (RFC 7591)
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
```
### Multi-Replica Deployment
Without Redis, disable replay detection:
```yaml
disableReplayDetection: true
```
With Redis (recommended):
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
See [REDIS.md](REDIS.md) for complete Redis configuration.
---
## Kubernetes Secrets
Reference secrets instead of hardcoding sensitive values:
```yaml
providerURL: urn:k8s:secret:oidc-secret:ISSUER
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
clientSecret: urn:k8s:secret:oidc-secret:SECRET
```
Create the secret:
```bash
kubectl create secret generic oidc-secret \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=your-client-id \
--from-literal=SECRET=your-client-secret \
-n traefik
```
---
## Environment Variable Naming
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
```yaml
# Bad - may cause issues
sessionEncryptionKey: ${OIDC_SECRET_API}
# Good
sessionEncryptionKey: ${OIDC_SECRET_SVC}
```
+455
View File
@@ -0,0 +1,455 @@
# Development Guide
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Local Development Setup](#local-development-setup)
- [Running Tests](#running-tests)
- [Test Categories](#test-categories)
- [CI/CD Pipeline](#cicd-pipeline)
- [Code Quality](#code-quality)
- [Contributing](#contributing)
---
## Prerequisites
- **Go 1.23+** for plugin compilation
- **Docker & Docker Compose** for local testing
- **OIDC Provider** credentials (Google, Azure, etc.)
### Required Development Tools
```bash
# golangci-lint (comprehensive linting)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck (static analysis)
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec (security scanning)
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck (vulnerability scanning)
go install golang.org/x/vuln/cmd/govulncheck@latest
```
---
## Local Development Setup
### Docker Compose Environment
The repository includes a Docker Compose setup for testing the plugin locally.
#### 1. Host Configuration
Add to `/etc/hosts`:
```bash
127.0.0.1 hello.localhost
127.0.0.1 traefik.localhost
```
#### 2. Plugin Configuration
The plugin is loaded using Traefik's **local plugins mode**:
- Plugin source: Parent directory (`../`)
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
- Configuration: `experimental.localPlugins` in `traefik.yml`
#### 3. OIDC Provider Setup
Edit `docker/dynamic.yml` with your provider details:
**Google:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
scopes:
- "openid"
- "email"
- "profile"
```
**Azure AD:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
scopes:
- "openid"
- "email"
- "profile"
```
#### 4. Start Environment
```bash
cd docker
docker-compose up -d
```
#### 5. Test Plugin
- **Protected App**: http://hello.localhost (redirects to OIDC)
- **Traefik Dashboard**: http://traefik.localhost:8080
### Development Workflow
1. **Edit plugin code** in the project root
2. **Build and test** (optional syntax check):
```bash
go mod tidy
go build .
go test ./...
```
3. **Restart Traefik** to reload plugin:
```bash
docker-compose restart traefik
```
4. **Test changes** at http://hello.localhost
### Debugging
**View plugin logs:**
```bash
docker-compose logs -f traefik | grep traefikoidc
```
**Check plugin loading:**
```bash
docker-compose logs traefik | grep -i plugin
```
**Verify plugin directory:**
```bash
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
```
---
## Running Tests
### Quick Start
```bash
# Fast development testing (< 30 seconds)
go test ./... -short
# Standard tests with race detector
go test -race -timeout=15m ./...
# With coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
```
### Test Modes
| Mode | Command | Duration | Use Case |
|------|---------|----------|----------|
| Quick | `go test ./... -short` | < 30s | During development |
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
### Environment Variables
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1
export RUN_LONG_TESTS=1
export RUN_STRESS_TESTS=1
# Disable specific features
export DISABLE_LEAK_DETECTION=1
# Customize test parameters
export TEST_MAX_CONCURRENCY=10
export TEST_MAX_ITERATIONS=50
export TEST_MEMORY_THRESHOLD_MB=25.5
```
---
## Test Categories
### Quick Tests (Default)
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Essential memory leak checks
**Configuration:**
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Timeout: 10 seconds
### Extended Tests
- Comprehensive testing before commits
- More iterations (5-10)
- Enhanced memory leak detection
**Configuration:**
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Timeout: 30 seconds
### Long Tests
- Performance validation
- High iteration counts (50-100)
- Large data sets
**Configuration:**
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Timeout: 60 seconds
### Stress Tests
- Maximum load testing
- Edge case validation
- Extreme parameters
**Configuration:**
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Timeout: 120 seconds
### Running Specific Test Suites
```bash
# Memory leak tests
go test -v -run='.*Leak.*' ./...
# Integration tests
go test -v -run='.*Integration.*' ./...
# Regression tests
go test -v -run='.*Regression.*' ./...
# Provider-specific tests
go test -v -run='.*Azure.*' ./...
go test -v -run='.*Google.*' ./...
```
### Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
---
## CI/CD Pipeline
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
### Triggered On
- Pull requests to `main` branch
- Pushes to `main` branch
### Parallel Jobs
#### Code Quality (3 checks)
- **Format & Basic Checks** - gofmt, go vet, go mod
- **golangci-lint** - 30+ linters
- **Staticcheck** - Advanced static analysis
#### Security (3 checks)
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database
- **CodeQL** - GitHub's semantic code analysis
#### Testing (9 suites)
- Race Detector
- Coverage (75% threshold)
- Memory Leaks
- Integration Tests
- Regression Tests
- Security Edge Cases
- Session Tests
- Token Tests
- CSRF Tests
#### Provider Testing (9 providers)
Tests run in parallel for:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (Go 1.23 & 1.24)
### Quality Gates
All PRs must pass:
- All parallel checks
- 75% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
- All providers tested
- Builds on all platforms
---
## Code Quality
### Pre-Commit Checklist
```bash
# Run before every commit
gofmt -s -w . && \
go mod tidy && \
golangci-lint run && \
go test -race -short ./... && \
echo "Ready to commit!"
```
### Local Validation
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Static analysis
staticcheck ./...
# Security scan
gosec ./...
# Vulnerability check
govulncheck ./...
# Tests with race detector
go test -race -timeout=15m -count=1 ./...
# Coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# View coverage in browser
go tool cover -html=coverage.out
```
### Troubleshooting
**Coverage Below Threshold:**
```bash
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out # See uncovered lines
```
**Race Condition Found:**
```bash
go test -race -v -run=TestName ./...
```
**Linter Errors:**
```bash
golangci-lint run -v
golangci-lint run --fix # Auto-fix some issues
```
**Provider Test Fails:**
```bash
go test -v -run='.*Azure.*' ./...
```
---
## Contributing
### Development Guidelines
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
4. **Documentation**: Update README and configuration files for new options
### Pull Request Template
PRs should include:
- Description of changes
- Type of change (bug fix, feature, breaking change, etc.)
- Related issues
- Provider impact (which providers are affected)
- Testing performed
- Security considerations
- Performance impact
- Breaking changes (if any)
### Checklist
Before submitting:
- [ ] Code follows project style
- [ ] Self-review completed
- [ ] Tests added for new functionality
- [ ] All tests pass locally
- [ ] Documentation updated
- [ ] No new warnings generated
### Code Owners
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
### Dependabot
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
---
## Additional Resources
- [golangci-lint Rules](.golangci.yml)
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
- [Workflow Documentation](.github/workflows/README.md)
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
+580
View File
@@ -0,0 +1,580 @@
# OIDC Provider Configuration Guide
Configuration reference for each supported OIDC provider.
## Table of Contents
- [Provider Support Matrix](#provider-support-matrix)
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [Okta](#okta)
- [Keycloak](#keycloak)
- [AWS Cognito](#aws-cognito)
- [GitLab](#gitlab)
- [GitHub](#github)
- [Generic OIDC](#generic-oidc)
- [Automatic Scope Filtering](#automatic-scope-filtering)
---
## Provider Support Matrix
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com` | Yes |
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
| Generic | Full | Yes | Any OIDC endpoint | Yes |
---
## Google
### Provider URL
```yaml
providerURL: "https://accounts.google.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
spec:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-id.apps.googleusercontent.com"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- email
- profile
allowedUserDomains:
- "your-gsuite-domain.com" # Optional: Workspace restriction
forceHttps: true
enablePkce: true
```
### Google-Specific Features
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
### Google Cloud Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Navigate to APIs & Services > Credentials
4. Create OAuth 2.0 Client ID (Web application)
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
6. Configure OAuth consent screen (must be "Published" for production)
---
## Microsoft Azure AD
### Provider URL
```yaml
# Single tenant
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# Multi-tenant
providerURL: "https://login.microsoftonline.com/common/v2.0"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
allowedRolesAndGroups:
- "App.Users"
- "Admin.Group"
forceHttps: true
```
### With Application ID URI (API Access)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-api
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
audience: "api://your-azure-client-id" # Application ID URI
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
forceHttps: true
```
### Users Without Email
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "user-object-id-1"
- "user-object-id-2"
```
### Azure AD Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to Azure Active Directory > App registrations
3. Create new registration
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
5. Create client secret in Certificates & secrets
6. Configure Token Configuration for group claims
---
## Auth0
### Provider URL
```yaml
providerURL: "https://your-domain.auth0.com"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
postLogoutRedirectUri: "https://your-app.com"
forceHttps: true
enablePkce: true
```
### With Custom API Audience
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-api
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
audience: "https://api.your-domain.com" # API identifier
strictAudienceValidation: true
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
roleClaimName: "https://your-app.com/roles" # Namespaced claim
groupClaimName: "https://your-app.com/groups"
allowedRolesAndGroups:
- admin
- editor
```
### Auth0 Action for Custom Claims
```javascript
exports.onExecutePostLogin = async (event, api) => {
const namespace = 'https://your-app.com/';
if (event.authorization) {
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
api.idToken.setCustomClaim('email', event.user.email);
}
};
```
### Auth0 Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create Regular Web Application
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
5. Enable OIDC Conformant in Advanced Settings
6. Create API in APIs section for custom audiences
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
---
## Okta
### Provider URL
```yaml
providerURL: "https://your-domain.okta.com"
# Or with custom authorization server:
providerURL: "https://your-domain.okta.com/oauth2/default"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-okta
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.okta.com"
clientID: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- groups
- offline_access
allowedRolesAndGroups:
- admin
- "Everyone"
forceHttps: true
enablePkce: true
```
### Okta Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect > Web Application
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
6. Enable Authorization Code and Refresh Token grant types
7. Configure Groups claim in authorization server
---
## Keycloak
### Provider URL
```yaml
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-keycloak
spec:
plugin:
traefikoidc:
providerURL: "https://keycloak.company.com/realms/your-realm"
clientID: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- roles
- groups
- offline_access
allowedRolesAndGroups:
- admin
- editor
forceHttps: true
enablePkce: true
```
### Internal Network Deployment
For private IP addresses (Docker networks, Kubernetes):
```yaml
providerURL: "https://192.168.1.100:8443/realms/your-realm"
allowPrivateIPAddresses: true # Required for private IPs
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select your realm
3. Go to Clients > Create client
4. Set Client Protocol: openid-connect
5. Set Access Type: confidential
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
7. Generate client secret in Credentials tab
8. Configure mappers to add claims to ID Token:
- Email: User Property mapper with "Add to ID token" enabled
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
---
## AWS Cognito
### Provider URL
```yaml
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-cognito
spec:
plugin:
traefikoidc:
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientID: "your-cognito-client-id"
clientSecret: "your-cognito-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- aws.cognito.signin.user.admin
allowedRolesAndGroups:
- admin
- users
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/oauth2/callback`
- Sign out URLs: `https://your-domain.com/oauth2/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
5. Set up groups for role-based access
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerURL: "https://gitlab.com"
# Self-hosted
providerURL: "https://gitlab.your-company.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-gitlab
spec:
plugin:
traefikoidc:
providerURL: "https://gitlab.com"
clientID: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
# Note: GitLab doesn't require offline_access scope
# Refresh tokens are issued automatically with openid
allowedRolesAndGroups:
- developers
- maintainers
forceHttps: true
enablePkce: true
```
### GitLab Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
5. Save and note Application ID and Secret
---
## GitHub
### Provider URL
```yaml
providerURL: "https://github.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oauth-github
spec:
plugin:
traefikoidc:
providerURL: "https://github.com/login/oauth"
clientID: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- user:email
- read:user
allowedUsers:
- "github-username"
forceHttps: true
```
### Limitations
- **OAuth 2.0 only** - Not OpenID Connect
- **No ID tokens** - Only access tokens for API calls
- **No refresh tokens** - Users must re-authenticate on expiry
- **No standard claims** - User info requires API calls
Use GitHub only for API access, not for user authentication with claims.
### GitHub Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
4. Note Client ID and generate Client Secret
---
## Generic OIDC
For any OIDC-compliant provider not listed above.
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-generic
spec:
plugin:
traefikoidc:
providerURL: "https://oidc.your-provider.com"
clientID: "your-client-id"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
forceHttps: true
enablePkce: true
```
### Requirements
- Provider must expose `.well-known/openid-configuration` endpoint
- Must support authorization code flow
- ID tokens must contain required claims (email, sub, etc.)
---
## Automatic Scope Filtering
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
### How It Works
1. Fetches provider's `.well-known/openid-configuration`
2. Extracts `scopes_supported` field
3. Filters requested scopes to only include supported ones
4. Falls back to all requested scopes if provider doesn't declare supported scopes
### Example: Self-Hosted GitLab
Self-hosted GitLab may reject `offline_access` scope:
```yaml
scopes:
- openid
- profile
- email
- offline_access # Will be automatically filtered out if unsupported
```
The middleware will:
1. Read GitLab's discovery document
2. Detect `offline_access` is NOT in `scopes_supported`
3. Filter it out automatically
4. Authentication succeeds
### Logging
```
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
If a provider rejects scopes even after filtering:
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
2. Use `overrideScopes: true` with only supported scopes
3. Review middleware debug logs for filtering decisions
+546
View File
@@ -0,0 +1,546 @@
# Redis Cache for Distributed Deployments
Redis cache support for multi-replica Traefik deployments with shared state.
## Table of Contents
- [Overview](#overview)
- [Why Use Redis Cache?](#why-use-redis-cache)
- [Configuration](#configuration)
- [Cache Modes](#cache-modes)
- [Deployment Examples](#deployment-examples)
- [Performance Tuning](#performance-tuning)
- [Monitoring](#monitoring)
- [Troubleshooting](#troubleshooting)
- [Migration Guide](#migration-guide)
---
## Overview
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
### Key Features
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
- **Shared Session Management**: Consistent user sessions across replicas
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
- **Health Checking**: Continuous monitoring of Redis connectivity
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
### Architecture
```
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3 │
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
│ │ │
└────────────────────┼────────────────────┘
┌──────▼──────┐
│ Redis │
│ (Shared │
│ Cache) │
└─────────────┘
```
---
## Why Use Redis Cache?
### The Problem
When running multiple Traefik instances without shared cache:
1. **False Positive Replay Detection**
- User authenticates → Token stored in Instance A's JTI cache
- Next request → Load balancer routes to Instance B
- Instance B doesn't have the JTI → Falsely detects replay attack
2. **Session Inconsistency**
- User session created on Instance A
- Subsequent request routed to Instance B
- Instance B has no knowledge of the session
3. **Token Metadata Fragmentation**
- Token refresh happens on Instance A
- Other instances continue using old tokens
### The Solution
Redis provides centralized cache that all instances share, ensuring:
- **Consistent Authentication**: All instances share authentication state
- **True Replay Detection**: JTI cache shared across all instances
- **Seamless Scaling**: Add/remove instances without affecting sessions
- **High Availability**: Circuit breaker with automatic fallback
---
## Configuration
### Basic Configuration
```yaml
redis:
enabled: true
address: "redis:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
```
### All Configuration Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enabled` | bool | `false` | Enable Redis caching |
| `address` | string | - | Redis server address (`host:port`) |
| `password` | string | - | Redis password (optional) |
| `db` | int | `0` | Redis database number (0-15) |
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
| `poolSize` | int | `10` | Connection pool size |
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
| `readTimeout` | int | `3` | Read timeout (seconds) |
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables are used:
```bash
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=your-password
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=10
REDIS_CONNECT_TIMEOUT=5
REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
```
---
## Cache Modes
### Memory Mode (Default without Redis)
```yaml
redis:
cacheMode: "memory"
```
- Uses only in-memory cache
- Suitable for single-instance deployments
- No Redis dependency
- Fastest performance
### Redis Mode
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
- All operations go directly to Redis
- Ensures consistency across replicas
- Slightly higher latency
### Hybrid Mode (Recommended)
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
Two-tier caching strategy:
```
┌─────────────────────────────────────────┐
│ Client Request │
└────────────────┬────────────────────────┘
┌────────────────┐
│ Local Cache │ ← L1 Cache (Fast)
│ (Memory) │
└────────┬───────┘
│ Miss
┌────────────────┐
│ Remote Cache │ ← L2 Cache (Shared)
│ (Redis) │
└────────────────┘
```
**Read Path:**
1. Check local memory cache (L1)
2. On miss, check Redis (L2)
3. On hit in Redis, populate L1
4. Return value
**Write Path:**
1. Write to Redis (L2) for durability
2. Write to local cache (L1) for speed
### Performance Comparison
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|-----------|------------|------------|-------------|
| Read (p50) | 0.1ms | 2ms | 0.2ms |
| Read (p99) | 0.5ms | 10ms | 5ms |
| Write (p50) | 0.2ms | 3ms | 3ms |
| Throughput | 100k/s | 20k/s | 80k/s |
---
## Deployment Examples
### Docker Compose
```yaml
version: '3.8'
services:
redis:
image: redis:7-alpine
command: redis-server --requirepass ${REDIS_PASSWORD}
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 30s
timeout: 3s
retries: 3
traefik:
image: traefik:v3.2
deploy:
replicas: 3
labels:
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
depends_on:
redis:
condition: service_healthy
volumes:
redis-data:
```
### Kubernetes
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
circuitBreakerThreshold: 5
```
### AWS ElastiCache
```yaml
redis:
enabled: true
address: "your-cache.abc123.cache.amazonaws.com:6379"
cacheMode: "hybrid"
enableTLS: true
password: "your-elasticache-auth-token"
```
---
## Performance Tuning
### Connection Pool Sizing
```yaml
redis:
poolSize: 20 # Formula: 2 * CPU cores * replicas
# For 4 cores, 3 replicas: poolSize = 24
```
### TTL Strategy
The plugin automatically sets TTLs based on token lifetimes:
- **JTI Cache**: Matches token lifetime (typically 1 hour)
- **Session**: Matches `sessionMaxAge` configuration
- **Token Metadata**: 5 minutes (short-lived)
### Redis Server Configuration
```bash
# Recommended Redis settings for cache
maxmemory 512mb
maxmemory-policy allkeys-lru # Evict least recently used
# For cache data, disable persistence for better performance
save ""
appendonly no
```
### Hybrid Mode Tuning
```yaml
redis:
cacheMode: "hybrid"
hybridL1Size: 500 # Max items in local cache
hybridL1MemoryMB: 10 # Max memory for local cache
```
---
## Monitoring
### Key Metrics
- **Cache hit rate** (target: >90% for hybrid mode)
- **Redis latency** (target: <10ms p99)
- **Circuit breaker state**
- **Connection pool utilization
### Redis Commands for Monitoring
```bash
# Monitor commands in real-time
redis-cli MONITOR
# Check slow queries
redis-cli SLOWLOG GET 10
# Memory usage
redis-cli INFO memory
# Key statistics
redis-cli DBSIZE
# List keys with prefix
redis-cli --scan --pattern "traefikoidc:*"
# Check key TTL
redis-cli TTL "traefikoidc:session:abc123"
```
### Health Check Endpoint
The plugin provides health information including:
```json
{
"status": "healthy",
"cache": {
"mode": "hybrid",
"redis": {
"connected": true,
"latency": "2ms"
},
"circuit_breaker": {
"state": "closed",
"failures": 0
}
}
}
```
---
## Troubleshooting
### Connection Refused
**Symptoms:** `dial tcp: connection refused`
**Solutions:**
1. Verify Redis is running: `redis-cli ping`
2. Check network connectivity: `telnet redis-host 6379`
3. Verify address configuration
### Authentication Failure
**Symptoms:** `NOAUTH Authentication required`
**Solutions:**
1. Set Redis password in configuration
2. Verify password is correct
### Circuit Breaker Open
**Symptoms:** `Circuit breaker is open`, falling back to memory
**Solutions:**
1. Check Redis health: `redis-cli INFO server`
2. Review network latency: `redis-cli --latency`
3. Adjust circuit breaker thresholds if needed
### High Memory Usage
**Symptoms:** Redis memory constantly growing, OOM errors
**Solutions:**
1. Configure eviction policy:
```bash
CONFIG SET maxmemory 512mb
CONFIG SET maxmemory-policy allkeys-lru
```
2. Review key count: `redis-cli DBSIZE`
3. Check for large keys: `redis-cli --bigkeys`
### Inconsistent Cache State
**Symptoms:** Different responses from different replicas
**Solutions:**
1. Verify all instances use the same Redis address
2. Check cache mode consistency across instances
3. Verify time synchronization on all hosts
---
## Migration Guide
### From Memory-Only to Redis
#### Phase 1: Preparation
1. Deploy Redis infrastructure
2. Test Redis connectivity
3. Configure monitoring
#### Phase 2: Gradual Rollout
1. Enable Redis on one instance:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
2. Monitor for errors
3. Gradually enable on more instances
#### Phase 3: Full Migration
1. Enable Redis on all instances
2. Remove `disableReplayDetection: true` if set
3. Monitor for issues
### Rollback Plan
If issues occur:
1. Set `redis.enabled: false`
2. Plugin falls back to memory cache automatically
3. Investigate and resolve issues
### Migration Checklist
- [ ] Redis deployed and accessible
- [ ] Redis password configured
- [ ] Network connectivity verified
- [ ] Monitoring configured
- [ ] Backup plan prepared
- [ ] Test environment validated
- [ ] Gradual rollout planned
---
## Best Practices
### Security
- Always use Redis password authentication
- Enable TLS for production deployments
- Use network segmentation (private subnets)
- Rotate Redis passwords regularly
### High Availability
- Use Redis Sentinel or Cluster for HA
- Configure appropriate circuit breaker thresholds
- Implement proper health checks
- Use connection pooling
### Performance
- Use hybrid cache mode for best performance
- Monitor cache hit rates
- Size Redis memory appropriately
- Disable persistence for cache-only usage
### Operations
- Implement comprehensive monitoring
- Set up alerting for circuit breaker state
- Document Redis configuration
- Test failover scenarios
---
## FAQ
### Is Redis required?
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
### What happens if Redis goes down?
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
### Which cache mode should I use?
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
### How much memory does Redis need?
Depends on active sessions and token sizes:
- Small (1-1000 users): 128MB
- Medium (1000-10000 users): 256-512MB
- Large (10000+ users): 1GB+
### Can I use managed Redis services?
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
### Is data encrypted in Redis?
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
+390
View File
@@ -0,0 +1,390 @@
# Testing Guide
Comprehensive testing infrastructure for traefikoidc.
## Overview
| Metric | Value |
|--------|-------|
| Test files | 99 |
| Lines of test code | ~65,500 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
## Running Tests
```bash
# Run all tests
go test ./...
# Run with race detection
go test -race ./...
# Run with coverage
go test -cover ./...
# Run specific test suite
go test -v -run "TokenValidationSuite" .
# Run edge case tests
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
```
## Test Infrastructure
### Directory Structure
```
internal/testutil/
├── compat.go # Re-exports for main package access
├── mocks/
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
│ ├── session.go # SessionManager, SessionData
│ ├── cache.go # Cache, TokenCache, Blacklist
│ └── interfaces_test.go # Mock verification tests
├── fixtures/
│ └── tokens.go # JWT token generation fixtures
└── servers/
├── oidc.go # Mock OIDC server factory
└── oidc_test.go # Server tests
```
### Test Suites
| Suite | File | Description |
|-------|------|-------------|
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
## Mock Types
The project provides two mocking patterns:
### State-Based Mocks (Basic)
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
| Mock | Interface | Description |
|------|-----------|-------------|
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
**Usage:**
```go
mock := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tOidc := &TraefikOidc{
jwkCache: mock,
// ...
}
```
### Enhanced State-Based Mocks (with Call Tracking)
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
| Mock | Interface | Description |
|------|-----------|-------------|
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
**Usage:**
```go
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
}
// Make calls
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
// Verify calls were made
mock.AssertGetJWKSCalled(t)
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
mock.AssertGetJWKSCallCount(t, 1)
// Access call details
s.Equal(1, mock.GetJWKSCallCount())
```
**Features:**
- Track all calls with parameters and timestamps
- Built-in assertion helpers using testify
- Thread-safe for concurrent tests
- `Reset()` method to clear state between tests
- `LastCall()` to inspect most recent call
### Testify-Based Mocks
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
| Mock | Interface | Description |
|------|-----------|-------------|
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
**Usage:**
```go
mock := &TestifyJWKCache{}
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
// After test
mock.AssertExpectations(t)
```
### Testutil Package Mocks
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
```go
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
mock := testutil.NewJWKCacheMock()
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
```
### Choosing the Right Mock
| Use Case | Recommended Mock |
|----------|-----------------|
| Simple return values only | Basic state-based (`MockJWKCache`) |
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
| Verify call order/sequence | Testify-based |
| HTTP endpoint simulation | `MockOAuthProvider` |
| New testify suite tests | Enhanced or Testify-based |
**Decision Guide:**
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
## Token Fixtures
The `testutil.TokenFixture` generates JWT tokens for testing:
```go
fixture, err := testutil.NewTokenFixture()
// Valid token with default claims
token, _ := fixture.ValidToken(nil)
// Token with custom claims
token, _ := fixture.ValidToken(map[string]interface{}{
"email": "test@example.com",
"roles": []string{"admin"},
})
// Expired token
token, _ := fixture.ExpiredToken()
// Token with specific roles/groups
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
token, _ := fixture.TokenWithGroups([]string{"developers"})
// Token with clock skew
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
// Token missing specific claims
token, _ := fixture.TokenMissingClaim("email", "sub")
// Malformed token
token := fixture.MalformedToken() // "not.a.valid.jwt"
// Get JWKS for verification
jwks := fixture.GetJWKS()
```
## Mock OIDC Server
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
```go
// Default configuration
server := testutil.NewOIDCServer(nil)
defer server.Close()
// Custom configuration
config := testutil.DefaultServerConfig()
config.Issuer = "https://custom-issuer.com"
config.TokenError = &testutil.OIDCError{
Error: "invalid_grant",
Description: "Authorization code expired",
}
server := testutil.NewOIDCServer(config)
// Provider-specific configurations
googleConfig := testutil.GoogleServerConfig()
azureConfig := testutil.AzureServerConfig()
auth0Config := testutil.Auth0ServerConfig()
keycloakConfig := testutil.KeycloakServerConfig()
// Behavior configurations
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
```
### Server Endpoints
| Endpoint | Description |
|----------|-------------|
| `/.well-known/openid-configuration` | OIDC discovery document |
| `/authorize` | Authorization endpoint |
| `/token` | Token exchange endpoint |
| `/jwks` | JSON Web Key Set |
| `/userinfo` | User information endpoint |
| `/introspect` | Token introspection |
| `/revoke` | Token revocation |
| `/logout` | End session endpoint |
### Request Tracking
```go
server := testutil.NewOIDCServer(nil)
// Make requests...
count := server.GetRequestCount()
requests := server.GetRequests()
server.Reset() // Clear tracking
```
## Writing Test Suites
### Basic Suite Structure
```go
type MyTestSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *MyTestSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *MyTestSuite) SetupTest() {
// Per-test setup
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
// ...
}
}
func (s *MyTestSuite) TearDownTest() {
// Per-test cleanup
}
func (s *MyTestSuite) TestSomething() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err)
}
func TestMyTestSuite(t *testing.T) {
suite.Run(t, new(MyTestSuite))
}
```
### Table-Driven Tests
```go
func (s *MyTestSuite) TestClockSkewEdgeCases() {
testCases := []struct {
name string
skew time.Duration
shouldPass bool
}{
{"valid_token", 5 * time.Minute, true},
{"expired_within_tolerance", -1 * time.Minute, true},
{"expired_beyond_tolerance", -10 * time.Minute, false},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
token, err := s.fixture.TokenWithSkew(tc.skew)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
if tc.shouldPass {
s.NoError(err)
} else {
s.Error(err)
}
})
}
}
```
## Test Categories
### Happy Path Tests
Test the expected successful scenarios:
- Valid token verification
- Successful token exchange
- Session creation and retrieval
- Cache operations
### Error Case Tests
Test failure scenarios:
- Expired tokens
- Invalid signatures
- Wrong issuer/audience
- Network failures
- Rate limiting
### Edge Case Tests
Test boundary conditions:
- Clock skew tolerance boundaries
- Unicode/emoji in claims
- Very large claim values
- Concurrent access
- Special characters in URLs
## Best Practices
1. **Use fixtures for token generation** - Don't manually construct JWTs
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
3. **Always run with `-race`** - Catch concurrency issues early
4. **Use testify assertions** - Better error messages and cleaner code
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
6. **Test edge cases systematically** - Use table-driven tests
-163
View File
@@ -1,163 +0,0 @@
# Google OAuth Integration Fix
## Problem Overview
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
```
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
```
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
## Technical Details of the Issue
### Standard OIDC Provider Behavior
Most OpenID Connect (OIDC) providers follow the standard specification, where:
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
- This allows authenticated sessions to persist beyond the initial access token expiration
### Google's Non-Standard Approach
Google's OAuth implementation deviates from the standard by:
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
## Solution Implementation
The fix involved modifying the authentication flow to specifically handle Google providers:
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
```go
// Check if we're dealing with a Google OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
strings.Contains(t.issuerURL, "accounts.google.com")
```
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
```go
// Handle offline access differently for Google vs other providers
if isGoogleProvider {
// For Google, use access_type=offline parameter instead of offline_access scope
params.Set("access_type", "offline")
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
// Add prompt=consent for Google to ensure refresh token is issued
params.Set("prompt", "consent")
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else {
// For non-Google providers, use the offline_access scope
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
}
}
```
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
## Why This Approach Works
This solution aligns with Google's OAuth 2.0 documentation which specifies:
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
## Testing and Verification
Comprehensive tests were implemented to verify the solution:
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
2. **Auth URL Parameter Tests**: Verifies that:
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
Sample test case (simplified):
```go
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that access_type=offline was added (not offline_access scope for Google)
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
}
// Verify offline_access scope is NOT included for Google providers
if strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
```
## Usage Guidance for Developers
When configuring the Traefik OIDC middleware for Google:
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
- Configure the authorized redirect URI to match your `callbackURL` setting
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
3. **Configuration Example**:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-google-client-id.apps.googleusercontent.com
clientSecret: your-google-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
scopes:
- openid
- email
- profile
# Note: DO NOT manually add offline_access scope for Google
# The middleware handles this automatically and correctly
```
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
- Review your application logs with `logLevel: debug` to check for refresh token errors
- Verify you're using a version of the middleware that includes this fix
## Conclusion
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
+1373
View File
File diff suppressed because it is too large Load Diff
+540
View File
@@ -0,0 +1,540 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
)
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// ClientRegistrationError represents an error response from client registration (RFC 7591)
type ClientRegistrationError struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
}
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
type DynamicClientRegistrar struct {
httpClient *http.Client
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
providerURL string
mu sync.RWMutex
}
// NewDynamicClientRegistrar creates a new dynamic client registrar
func NewDynamicClientRegistrar(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
}
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from a file if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
return nil, fmt.Errorf("dynamic client registration is not enabled")
}
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
if resp, err := r.loadCredentials(); err == nil && resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from file")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
return resp, nil
}
r.logger.Info("Existing credentials expired or invalid, registering new client")
}
}
// Determine registration endpoint
endpoint := registrationEndpoint
if r.config.RegistrationEndpoint != "" {
endpoint = r.config.RegistrationEndpoint
}
if endpoint == "" {
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
}
// Validate the endpoint URL
if !strings.HasPrefix(endpoint, "https://") {
// Allow http only for localhost/development
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
}
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
}
// Build registration request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build registration request: %w", err)
}
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
// Add Initial Access Token if provided
if r.config.InitialAccessToken != "" {
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
}
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
if err != nil {
return nil, fmt.Errorf("failed to read registration response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse registration response: %w", err)
}
// Validate response
if regResp.ClientID == "" {
return nil, fmt.Errorf("registration response missing client_id")
}
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
// Cache the response
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
}
return &regResp, nil
}
// buildRegistrationRequest creates the JSON request body for client registration
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
metadata := r.config.ClientMetadata
if metadata == nil {
metadata = &ClientRegistrationMetadata{}
}
// Build request object
reqData := make(map[string]interface{})
// Required: redirect_uris
if len(metadata.RedirectURIs) > 0 {
reqData["redirect_uris"] = metadata.RedirectURIs
} else {
return nil, fmt.Errorf("redirect_uris is required for client registration")
}
// Optional fields - only include if set
if len(metadata.ResponseTypes) > 0 {
reqData["response_types"] = metadata.ResponseTypes
} else {
// Default to authorization code flow
reqData["response_types"] = []string{"code"}
}
if len(metadata.GrantTypes) > 0 {
reqData["grant_types"] = metadata.GrantTypes
} else {
// Default grant types for authorization code flow
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
}
if metadata.ApplicationType != "" {
reqData["application_type"] = metadata.ApplicationType
}
if len(metadata.Contacts) > 0 {
reqData["contacts"] = metadata.Contacts
}
if metadata.ClientName != "" {
reqData["client_name"] = metadata.ClientName
}
if metadata.LogoURI != "" {
reqData["logo_uri"] = metadata.LogoURI
}
if metadata.ClientURI != "" {
reqData["client_uri"] = metadata.ClientURI
}
if metadata.PolicyURI != "" {
reqData["policy_uri"] = metadata.PolicyURI
}
if metadata.TOSURI != "" {
reqData["tos_uri"] = metadata.TOSURI
}
if metadata.JWKSURI != "" {
reqData["jwks_uri"] = metadata.JWKSURI
}
if metadata.SubjectType != "" {
reqData["subject_type"] = metadata.SubjectType
}
if metadata.TokenEndpointAuthMethod != "" {
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
} else {
// Default to client_secret_basic for confidential clients
reqData["token_endpoint_auth_method"] = "client_secret_basic"
}
if metadata.DefaultMaxAge > 0 {
reqData["default_max_age"] = metadata.DefaultMaxAge
}
if metadata.RequireAuthTime {
reqData["require_auth_time"] = metadata.RequireAuthTime
}
if len(metadata.DefaultACRValues) > 0 {
reqData["default_acr_values"] = metadata.DefaultACRValues
}
if metadata.Scope != "" {
reqData["scope"] = metadata.Scope
}
return json.Marshal(reqData)
}
// GetCachedResponse returns the cached registration response
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
r.mu.RLock()
defer r.mu.RUnlock()
return r.registrationResponse
}
// areCredentialsValid checks if the cached credentials are still valid
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
if resp == nil || resp.ClientID == "" {
return false
}
// Check if secret has expired
if resp.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
// Add 5 minute buffer before expiration
if time.Now().Add(5 * time.Minute).After(expiresAt) {
return false
}
}
return true
}
// credentialsFilePath returns the path for storing credentials
func (r *DynamicClientRegistrar) credentialsFilePath() string {
if r.config.CredentialsFile != "" {
return r.config.CredentialsFile
}
return "/tmp/oidc-client-credentials.json"
}
// saveCredentials persists client credentials to a file
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
data, err := json.MarshalIndent(resp, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
r.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// loadCredentials loads client credentials from a file
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var resp ClientRegistrationResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials file if persistence is enabled
if r.config.PersistCredentials {
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
r.logger.Errorf("Failed to remove credentials file: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
File diff suppressed because it is too large Load Diff
+620
View File
@@ -0,0 +1,620 @@
package traefikoidc
import (
"context"
"encoding/base64"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
"github.com/stretchr/testify/suite"
"golang.org/x/time/rate"
)
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
type ClockSkewEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
// Create JWK for the test key
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error") // Reduce noise
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
token, err := s.fixture.TokenWithSkew(0)
s.Require().NoError(err)
// Token at exact expiry - behavior is implementation-defined
err = s.tOidc.VerifyToken(token)
s.T().Logf("Exact expiry result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
token, err := s.fixture.TokenWithSkew(1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token should be valid 1 second before expiry")
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
}
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
// Most implementations allow 5-minute clock skew
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// May pass or fail depending on implementation
s.T().Logf("4-minute expired token result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.Error(err, "Token should be invalid 10 minutes after expiry")
}
func TestClockSkewEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ClockSkewEdgeCasesSuite))
}
// UnicodeClaimsSuite tests Unicode handling in JWT claims
type UnicodeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *UnicodeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *UnicodeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
token, err := s.fixture.TokenWithEmail("用户@example.com")
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode email should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestUnicodeName() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "田中太郎",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode name should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test User 😀",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Emoji in claims should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestRTLText() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "مستخدم اختبار",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "RTL text should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestMixedScripts() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test 测试 テスト",
"roles": []string{"admin", "管理者", "管理员"},
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Mixed scripts should be handled correctly")
}
func TestUnicodeClaimsSuite(t *testing.T) {
suite.Run(t, new(UnicodeClaimsSuite))
}
// LargeClaimsSuite tests large claim values
type LargeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *LargeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *LargeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *LargeClaimsSuite) TestManyRoles() {
roles := make([]string, 100)
for i := 0; i < 100; i++ {
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithRoles(roles)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 100 roles should be handled")
}
func (s *LargeClaimsSuite) TestManyGroups() {
groups := make([]string, 50)
for i := 0; i < 50; i++ {
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithGroups(groups)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 50 groups should be handled")
}
func (s *LargeClaimsSuite) TestLongEmail() {
// RFC 5321 allows up to 254 characters
localPart := strings.Repeat("a", 64)
domain := strings.Repeat("b", 63) + ".com"
email := localPart + "@" + domain
token, err := s.fixture.TokenWithEmail(email)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long email should be handled")
}
func (s *LargeClaimsSuite) TestLongSubject() {
longSub := strings.Repeat("subject", 100)
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"sub": longSub,
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long subject should be handled")
}
func TestLargeClaimsSuite(t *testing.T) {
suite.Run(t, new(LargeClaimsSuite))
}
// URLPathEdgeCasesSuite tests URL handling edge cases
type URLPathEdgeCasesSuite struct {
suite.Suite
}
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
longPath := "/" + strings.Repeat("segment/", 100)
req := httptest.NewRequest("GET", longPath, nil)
s.NotNil(req)
s.Contains(req.URL.Path, "segment")
}
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
paths := []string{
"/path%20with%20spaces",
"/path/with/日本語",
"/path?query=value&another=test",
"/path#fragment",
"/path/../traversal",
"/path/./current",
}
for _, path := range paths {
s.Run(path, func() {
req := httptest.NewRequest("GET", path, nil)
s.NotNil(req)
})
}
}
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
req := httptest.NewRequest("GET", "/", nil)
s.Equal("/", req.URL.Path)
}
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
req := httptest.NewRequest("GET", "//double//slashes//", nil)
s.NotNil(req)
}
func TestURLPathEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(URLPathEdgeCasesSuite))
}
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
type ConcurrencyEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
const goroutines = 50
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}()
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
const goroutines = 20
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"custom": idx,
})
if err != nil {
errors <- err
return
}
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent different token error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent different token validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
validToken, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
expiredToken, err := s.fixture.ExpiredToken()
s.Require().NoError(err)
const goroutines = 40
var wg sync.WaitGroup
validCount := int32(0)
expiredCount := int32(0)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
var token string
if idx%2 == 0 {
token = validToken
} else {
token = expiredToken
}
err := s.tOidc.VerifyToken(token)
if idx%2 == 0 {
if err == nil {
atomic.AddInt32(&validCount, 1)
}
} else {
if err != nil {
atomic.AddInt32(&expiredCount, 1)
}
}
}(i)
}
wg.Wait()
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
}
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
}
+258
View File
@@ -0,0 +1,258 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
type EnhancedMocksSuite struct {
suite.Suite
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Make some calls
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.NoError(err)
s.NotNil(result)
// Another call with different URL
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
// Verify calls were tracked
s.Equal(2, mock.GetJWKSCallCount())
mock.AssertGetJWKSCalled(s.T())
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
mock.AssertGetJWKSCallCount(s.T(), 2)
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
expectedErr := errors.New("network error")
mock := &EnhancedMockJWKCache{
Err: expectedErr,
}
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Nil(result)
s.Equal(expectedErr, err)
mock.AssertGetJWKSCalled(s.T())
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Equal(1, mock.GetJWKSCallCount())
mock.Reset()
s.Equal(0, mock.GetJWKSCallCount())
s.Nil(mock.JWKS)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
mock := &EnhancedMockTokenVerifier{
Err: nil, // Valid tokens
}
// Verify a token
err := mock.VerifyToken("test-token-1")
s.NoError(err)
// Verify another token
err = mock.VerifyToken("test-token-2")
s.NoError(err)
// Check tracking
s.Equal(2, mock.GetVerifyTokenCallCount())
mock.AssertVerifyTokenCalled(s.T())
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
// Check last call
lastCall := mock.LastCall()
s.NotNil(lastCall)
s.Equal("test-token-2", lastCall.Token)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
callCount := 0
mock := &EnhancedMockTokenVerifier{
VerifyFunc: func(token string) error {
callCount++
if token == "invalid" {
return errors.New("invalid token")
}
return nil
},
}
// Valid token
err := mock.VerifyToken("valid-token")
s.NoError(err)
// Invalid token
err = mock.VerifyToken("invalid")
s.Error(err)
s.Equal(2, callCount)
s.Equal(2, mock.GetVerifyTokenCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
mock := &EnhancedMockTokenExchanger{
ExchangeResponse: &TokenResponse{
AccessToken: "access-token",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
},
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
},
}
// Exchange code
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
s.NoError(err)
s.Equal("access-token", resp.AccessToken)
// Refresh token
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
s.NoError(err)
s.Equal("new-access-token", resp.AccessToken)
// Revoke token
err = mock.RevokeTokenWithProvider("access-token", "access_token")
s.NoError(err)
// Check tracking
mock.AssertExchangeCalled(s.T())
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
mock.AssertRefreshCalled(s.T())
mock.AssertRevokeCalled(s.T())
s.Equal(1, mock.GetExchangeCallCount())
s.Equal(1, mock.GetRefreshCallCount())
s.Equal(1, mock.GetRevokeCallCount())
// Check last exchange call details
lastExchange := mock.LastExchangeCall()
s.NotNil(lastExchange)
s.Equal("authorization_code", lastExchange.GrantType)
s.Equal("auth-code", lastExchange.CodeOrToken)
s.Equal("https://redirect.com", lastExchange.RedirectURL)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
mock := &EnhancedMockTokenExchanger{
ExchangeErr: errors.New("invalid_grant"),
RefreshErr: errors.New("refresh_expired"),
RevokeErr: errors.New("revoke_failed"),
}
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
s.Error(err)
s.Contains(err.Error(), "invalid_grant")
_, err = mock.GetNewTokenWithRefreshToken("token")
s.Error(err)
s.Contains(err.Error(), "refresh_expired")
err = mock.RevokeTokenWithProvider("token", "access_token")
s.Error(err)
s.Contains(err.Error(), "revoke_failed")
}
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
mock := NewEnhancedMockCache()
// Set some values
mock.Set("key1", "value1", 5*time.Minute)
mock.Set("key2", "value2", 10*time.Minute)
// Get values
val, found := mock.Get("key1")
s.True(found)
s.Equal("value1", val)
_, found = mock.Get("nonexistent")
s.False(found)
// Delete
mock.Delete("key1")
// Verify tracking
mock.AssertSetCalled(s.T(), "key1")
mock.AssertSetCalled(s.T(), "key2")
mock.AssertGetCalled(s.T(), "key1")
mock.AssertGetCalled(s.T(), "nonexistent")
mock.AssertDeleteCalled(s.T(), "key1")
s.Equal(2, mock.SetCallCount())
s.Equal(2, mock.GetCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
mock := NewEnhancedMockCache()
// The enhanced mock actually stores data
mock.Set("key", "value", time.Hour)
s.Equal(1, mock.Size())
val, found := mock.Get("key")
s.True(found)
s.Equal("value", val)
mock.Delete("key")
s.Equal(0, mock.Size())
_, found = mock.Get("key")
s.False(found)
}
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
mock := NewEnhancedMockCache()
mock.Set("key1", "value1", time.Hour)
mock.Set("key2", "value2", time.Hour)
s.Equal(2, mock.Size())
mock.Clear()
s.Equal(0, mock.Size())
}
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Concurrent calls should be safe
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
s.Equal(10, mock.GetJWKSCallCount())
}
func TestEnhancedMocksSuite(t *testing.T) {
suite.Run(t, new(EnhancedMocksSuite))
}
+577
View File
@@ -0,0 +1,577 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/stretchr/testify/assert"
)
// EnhancedMockJWKCache is an improved state-based mock with call tracking
type EnhancedMockJWKCache struct {
Err error
JWKS *JWKSet
GetJWKSCalls []JWKSCall
mu sync.RWMutex
getJWKSCallsMu sync.Mutex
CleanupCalls int32
CloseCalls int32
}
// JWKSCall records parameters from a GetJWKS call
type JWKSCall struct {
Timestamp time.Time
URL string
}
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
URL: jwksURL,
Timestamp: time.Now(),
})
m.getJWKSCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
defer m.mu.Unlock()
m.JWKS = nil
m.Err = nil
}
func (m *EnhancedMockJWKCache) Close() {
atomic.AddInt32(&m.CloseCalls, 1)
}
// Assertion helpers
// AssertGetJWKSCalled verifies GetJWKS was called
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
}
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
for _, call := range m.GetJWKSCalls {
if call.URL == expectedURL {
return true
}
}
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
}
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
}
// GetJWKSCallCount returns the number of GetJWKS calls
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return len(m.GetJWKSCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockJWKCache) Reset() {
m.mu.Lock()
m.JWKS = nil
m.Err = nil
m.mu.Unlock()
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = nil
m.getJWKSCallsMu.Unlock()
atomic.StoreInt32(&m.CleanupCalls, 0)
atomic.StoreInt32(&m.CloseCalls, 0)
}
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
type EnhancedMockTokenVerifier struct {
Err error
VerifyFunc func(token string) error
VerifyCalls []TokenVerifyCall
mu sync.RWMutex
verifyCallsMu sync.Mutex
}
// TokenVerifyCall records parameters from a VerifyToken call
type TokenVerifyCall struct {
Timestamp time.Time
Result error
Token string
}
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
var result error
m.mu.RLock()
if m.VerifyFunc != nil {
result = m.VerifyFunc(token)
} else {
result = m.Err
}
m.mu.RUnlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
Token: token,
Timestamp: time.Now(),
Result: result,
})
m.verifyCallsMu.Unlock()
return result
}
// Assertion helpers
// AssertVerifyTokenCalled verifies VerifyToken was called
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
}
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
for _, call := range m.VerifyCalls {
if call.Token == expectedToken {
return true
}
}
return assert.Fail(t, "VerifyToken was not called with expected token")
}
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
}
// GetVerifyTokenCallCount returns the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return len(m.VerifyCalls)
}
// LastCall returns the most recent VerifyToken call
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
if len(m.VerifyCalls) == 0 {
return nil
}
return &m.VerifyCalls[len(m.VerifyCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenVerifier) Reset() {
m.mu.Lock()
m.Err = nil
m.VerifyFunc = nil
m.mu.Unlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = nil
m.verifyCallsMu.Unlock()
}
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
type EnhancedMockTokenExchanger struct {
RefreshErr error
RevokeErr error
ExchangeErr error
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshResponse *TokenResponse
ExchangeResponse *TokenResponse
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
mu sync.RWMutex
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
}
// ExchangeCall records parameters from an ExchangeCodeForToken call
type ExchangeCall struct {
Timestamp time.Time
GrantType string
CodeOrToken string
RedirectURL string
CodeVerifier string
}
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
type RefreshCall struct {
Timestamp time.Time
RefreshToken string
}
// RevokeCall records parameters from a RevokeTokenWithProvider call
type RevokeCall struct {
Timestamp time.Time
Token string
TokenType string
}
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
m.exchangeCallsMu.Lock()
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
GrantType: grantType,
CodeOrToken: codeOrToken,
RedirectURL: redirectURL,
CodeVerifier: codeVerifier,
Timestamp: time.Now(),
})
m.exchangeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return m.ExchangeResponse, m.ExchangeErr
}
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
m.refreshCallsMu.Lock()
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
RefreshToken: refreshToken,
Timestamp: time.Now(),
})
m.refreshCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return m.RefreshResponse, m.RefreshErr
}
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
m.revokeCallsMu.Lock()
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
Token: token,
TokenType: tokenType,
Timestamp: time.Now(),
})
m.revokeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return m.RevokeErr
}
// Assertion helpers
// AssertExchangeCalled verifies ExchangeCodeForToken was called
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
}
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
for _, call := range m.ExchangeCalls {
if call.GrantType == grantType {
return true
}
}
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
}
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
}
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
}
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return len(m.ExchangeCalls)
}
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return len(m.RefreshCalls)
}
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return len(m.RevokeCalls)
}
// LastExchangeCall returns the most recent ExchangeCodeForToken call
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
if len(m.ExchangeCalls) == 0 {
return nil
}
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenExchanger) Reset() {
m.mu.Lock()
m.ExchangeResponse = nil
m.ExchangeErr = nil
m.RefreshResponse = nil
m.RefreshErr = nil
m.RevokeErr = nil
m.ExchangeCodeFunc = nil
m.RefreshTokenFunc = nil
m.RevokeTokenFunc = nil
m.mu.Unlock()
m.exchangeCallsMu.Lock()
m.ExchangeCalls = nil
m.exchangeCallsMu.Unlock()
m.refreshCallsMu.Lock()
m.RefreshCalls = nil
m.refreshCallsMu.Unlock()
m.revokeCallsMu.Lock()
m.RevokeCalls = nil
m.revokeCallsMu.Unlock()
}
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
type EnhancedMockCacheInterface struct {
data map[string]cacheEntry
GetCalls []CacheGetCall
SetCalls []CacheSetCall
DeleteCalls []string
maxSize int
mu sync.RWMutex
getCalls sync.Mutex
setCalls sync.Mutex
deleteCalls sync.Mutex
}
type cacheEntry struct {
value any
ttl time.Duration
}
// CacheGetCall records parameters from a Get call
type CacheGetCall struct {
Timestamp time.Time
Key string
Found bool
}
// CacheSetCall records parameters from a Set call
type CacheSetCall struct {
Timestamp time.Time
Value any
Key string
TTL time.Duration
}
// NewEnhancedMockCache creates a new enhanced cache mock
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
return &EnhancedMockCacheInterface{
data: make(map[string]cacheEntry),
maxSize: 1000,
}
}
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
m.setCalls.Lock()
m.SetCalls = append(m.SetCalls, CacheSetCall{
Key: key,
Value: value,
TTL: ttl,
Timestamp: time.Now(),
})
m.setCalls.Unlock()
m.mu.Lock()
m.data[key] = cacheEntry{value: value, ttl: ttl}
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
m.mu.RLock()
entry, found := m.data[key]
m.mu.RUnlock()
m.getCalls.Lock()
m.GetCalls = append(m.GetCalls, CacheGetCall{
Key: key,
Found: found,
Timestamp: time.Now(),
})
m.getCalls.Unlock()
if found {
return entry.value, true
}
return nil, false
}
func (m *EnhancedMockCacheInterface) Delete(key string) {
m.deleteCalls.Lock()
m.DeleteCalls = append(m.DeleteCalls, key)
m.deleteCalls.Unlock()
m.mu.Lock()
delete(m.data, key)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
m.mu.Lock()
m.maxSize = size
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Size() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.data)
}
func (m *EnhancedMockCacheInterface) Clear() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Cleanup() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) Close() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
m.mu.RLock()
defer m.mu.RUnlock()
return map[string]any{
"size": len(m.data),
"max_size": m.maxSize,
}
}
// Assertion helpers
// AssertGetCalled verifies Get was called with specific key
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
m.getCalls.Lock()
defer m.getCalls.Unlock()
for _, call := range m.GetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Get was not called with key: "+key)
}
// AssertSetCalled verifies Set was called with specific key
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
m.setCalls.Lock()
defer m.setCalls.Unlock()
for _, call := range m.SetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Set was not called with key: "+key)
}
// AssertDeleteCalled verifies Delete was called with specific key
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
m.deleteCalls.Lock()
defer m.deleteCalls.Unlock()
for _, k := range m.DeleteCalls {
if k == key {
return true
}
}
return assert.Fail(t, "Delete was not called with key: "+key)
}
// GetCallCount returns the number of Get calls
func (m *EnhancedMockCacheInterface) GetCallCount() int {
m.getCalls.Lock()
defer m.getCalls.Unlock()
return len(m.GetCalls)
}
// SetCallCount returns the number of Set calls
func (m *EnhancedMockCacheInterface) SetCallCount() int {
m.setCalls.Lock()
defer m.setCalls.Unlock()
return len(m.SetCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockCacheInterface) Reset() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
m.getCalls.Lock()
m.GetCalls = nil
m.getCalls.Unlock()
m.setCalls.Lock()
m.SetCalls = nil
m.setCalls.Unlock()
m.deleteCalls.Lock()
m.DeleteCalls = nil
m.deleteCalls.Unlock()
}
+151 -39
View File
@@ -2,10 +2,14 @@ package traefikoidc
import (
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"math"
"math/rand/v2"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -123,8 +127,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
}
if metrics["total_requests"].(int64) > 0 {
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
if totalReq > 0 {
successRate := float64(totalSucc) / float64(totalReq)
metrics["success_rate"] = successRate
} else {
metrics["success_rate"] = 1.0
@@ -409,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
// fetching during startup. Uses more aggressive retry settings to handle the race
// condition where Traefik initializes the plugin before routes are fully established,
// or before TLS certificates are properly loaded.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func MetadataFetchRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 10, // More attempts for startup scenarios
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
MaxDelay: 10 * time.Second, // Cap at 10 seconds
BackoffFactor: 1.5, // Gentler backoff for startup
EnableJitter: true, // Prevent thundering herd
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
}
// RetryExecutor implements retry logic with exponential backoff and jitter.
// It automatically retries failed operations based on configurable error patterns
// and uses exponential backoff to avoid overwhelming failing services.
@@ -485,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// isRetryableError checks if an error should trigger a retry
// isRetryableError determines if an error should trigger a retry attempt.
// Checks error message against configured retryable error patterns.
// Also handles startup-specific errors like Traefik default certificate errors
// and EOF errors that occur during service initialization.
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for Traefik default certificate error (startup race condition)
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
if isTraefikDefaultCertError(err) {
return true
}
// Check for EOF errors (common during startup when services aren't ready)
if isEOFError(err) {
return true
}
// Check for certificate errors (transient during startup)
if isCertificateError(err) {
return true
}
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
@@ -536,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
delay = float64(re.config.MaxDelay)
}
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.EnableJitter {
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
delay += jitter
@@ -592,14 +642,10 @@ func (e *HTTPError) Error() string {
// OIDCError represents OIDC-specific errors with context information.
// It provides structured error reporting for authentication and authorization failures.
type OIDCError struct {
// Code identifies the specific error type
Code string
// Message provides a human-readable description
Message string
// Context contains additional error context (e.g., provider, session details)
Cause error
Context map[string]interface{}
// Cause is the underlying error that caused this error
Cause error
Code string
Message string
}
// Error returns the string representation of the OIDC error.
@@ -619,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
// SessionError represents session-related errors with context.
// Used for session management, validation, and storage errors.
type SessionError struct {
// Operation describes what session operation failed
Cause error
Operation string
// Message provides a human-readable description
Message string
// SessionID identifies the session (if available)
Message string
SessionID string
// Cause is the underlying error that caused this error
Cause error
}
// Error returns the string representation of the session error.
@@ -646,14 +688,10 @@ func (e *SessionError) Unwrap() error {
// TokenError represents token-related errors with validation context.
// Used for JWT validation, token refresh, and token format errors.
type TokenError struct {
// TokenType identifies the type of token (id_token, access_token, refresh_token)
Cause error
TokenType string
// Reason describes why the token is invalid
Reason string
// Message provides a human-readable description
Message string
// Cause is the underlying error that caused this error
Cause error
Reason string
Message string
}
// Error returns the string representation of the token error.
@@ -715,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
// It provides fallback mechanisms when primary services are unavailable and monitors
// service health to automatically recover when services become available again.
type GracefulDegradation struct {
// BaseRecoveryMechanism provides common functionality
*BaseRecoveryMechanism
// fallbacks stores service-specific fallback implementations
fallbacks map[string]func() (interface{}, error)
// healthChecks stores service health check functions
healthChecks map[string]func() bool
// degradedServices tracks which services are currently degraded
fallbacks map[string]func() (interface{}, error)
healthChecks map[string]func() bool
degradedServices map[string]time.Time
// config contains graceful degradation configuration
config GracefulDegradationConfig
// mutex protects shared state
mutex sync.RWMutex
// healthCheckTask manages background health checking
healthCheckTask *BackgroundTask
// stopChan signals shutdown
stopChan chan struct{}
// shutdownOnce ensures shutdown happens only once
shutdownOnce sync.Once
healthCheckTask *BackgroundTask
stopChan chan struct{}
config GracefulDegradationConfig
mutex sync.RWMutex
shutdownOnce sync.Once
}
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
@@ -963,7 +992,7 @@ func (gd *GracefulDegradation) Close() {
// Don't set to nil to avoid race conditions
}
gd.logger.Info("GracefulDegradation shut down successfully")
gd.logger.Debug("GracefulDegradation shut down successfully")
})
}
@@ -1085,3 +1114,86 @@ func containsSubstring(s, substr string) bool {
}
return false
}
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
// certificate during cold-start, before the real certificates are loaded.
// This manifests as an x509.HostnameError where one of the certificate's DNS names
// ends with "traefik.default" (the default Traefik certificate pattern).
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func isTraefikDefaultCertError(err error) bool {
if err == nil {
return false
}
var hostnameErr x509.HostnameError
if errors.As(err, &hostnameErr) {
if hostnameErr.Certificate != nil {
for _, name := range hostnameErr.Certificate.DNSNames {
if strings.HasSuffix(name, "traefik.default") {
return true
}
}
}
}
return false
}
// isEOFError checks if an error is an EOF error, which can occur during
// connection establishment when the remote end closes unexpectedly.
// This is common during service startup when endpoints aren't fully ready.
func isEOFError(err error) bool {
if err == nil {
return false
}
// Check for direct EOF
if errors.Is(err, io.EOF) {
return true
}
// Check for unexpected EOF
if errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check error message for EOF patterns (wrapped errors)
errStr := err.Error()
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
}
// isCertificateError checks if an error is related to TLS certificate validation.
// These errors are often transient during startup when services are still initializing.
func isCertificateError(err error) bool {
if err == nil {
return false
}
// Check for x509 certificate errors
var certInvalidErr x509.CertificateInvalidError
var hostnameErr x509.HostnameError
var unknownAuthErr x509.UnknownAuthorityError
if errors.As(err, &certInvalidErr) ||
errors.As(err, &hostnameErr) ||
errors.As(err, &unknownAuthErr) {
return true
}
// Check error message for certificate patterns
errStr := strings.ToLower(err.Error())
certPatterns := []string{
"certificate",
"x509",
"tls",
"ssl",
}
for _, pattern := range certPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import "testing"
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
File diff suppressed because it is too large Load Diff
+486
View File
@@ -0,0 +1,486 @@
# ============================================================================
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
# ============================================================================
#
# This example shows a complete, production-ready configuration for using
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
#
# ============================================================================
# Part 1: Traefik Static Configuration (traefik.yml)
# ============================================================================
# This file configures Traefik itself and enables the plugin.
# Place this in /etc/traefik/traefik.yml or mount it in your container.
---
# Static Configuration
api:
dashboard: true
insecure: false # Set to true only for local development
entryPoints:
web:
address: ":80"
http:
redirections:
entryPoint:
to: websecure
scheme: https
websecure:
address: ":443"
http:
tls:
certResolver: letsencrypt
certificatesResolvers:
letsencrypt:
acme:
email: admin@example.com
storage: /letsencrypt/acme.json
httpChallenge:
entryPoint: web
providers:
file:
filename: /etc/traefik/dynamic.yml
watch: true
# Enable the TraefikOIDC plugin
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
log:
level: INFO
format: json
accessLog:
format: json
# ============================================================================
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
# ============================================================================
# This file defines your routes, services, and middleware.
# Place this in /etc/traefik/dynamic.yml
---
http:
# -------------------------------------------------------------------------
# Middleware Definitions
# -------------------------------------------------------------------------
middlewares:
# Example 1: Minimal Redis Configuration
# Perfect for getting started quickly
oidc-minimal:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-application-client-id"
clientSecret: "your-client-secret-from-provider"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
# Minimal Redis configuration
redis:
enabled: true
address: "redis:6379"
# Example 2: Production Redis Configuration
# Recommended for production deployments with multiple Traefik replicas
oidc-production:
plugin:
traefikoidc:
# OIDC Provider Configuration
clientID: "prod-client-id"
clientSecret: "prod-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
# Security Settings
forceHTTPS: true
strictAudienceValidation: true
# Redis Configuration for Multi-Replica Deployment
redis:
enabled: true
address: "redis-master.redis-namespace.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:prod:"
# Cache Strategy
cacheMode: "hybrid" # Fast local cache + shared Redis
# Connection Pooling
poolSize: 20
connectTimeout: 5
readTimeout: 3
writeTimeout: 3
# Resilience Features
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS (for production security)
oidc-secure:
plugin:
traefikoidc:
clientID: "secure-client-id"
clientSecret: "secure-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
enableTLS: true
tlsSkipVerify: false # Verify certificates in production
cacheMode: "redis"
# Example 4: Hybrid Mode (Best Performance + Consistency)
# Local cache for hot data, Redis for consistency across replicas
oidc-hybrid:
plugin:
traefikoidc:
clientID: "app-client-id"
clientSecret: "app-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
cacheMode: "hybrid"
# Hybrid mode L1 cache settings
hybridL1Size: 1000 # Number of items in local cache
hybridL1MemoryMB: 20 # MB of memory for local cache
# -------------------------------------------------------------------------
# Router Definitions
# -------------------------------------------------------------------------
routers:
# Protected application using OIDC authentication
my-app:
rule: "Host(`app.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-production # Use the OIDC middleware
service: my-app-service
tls:
certResolver: letsencrypt
# Another app with minimal OIDC config
simple-app:
rule: "Host(`simple.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-minimal
service: simple-app-service
tls:
certResolver: letsencrypt
# -------------------------------------------------------------------------
# Service Definitions
# -------------------------------------------------------------------------
services:
my-app-service:
loadBalancer:
servers:
- url: "http://my-app:8080"
healthCheck:
path: /health
interval: 30s
timeout: 5s
simple-app-service:
loadBalancer:
servers:
- url: "http://simple-app:3000"
# ============================================================================
# Part 3: Docker Compose Example
# ============================================================================
---
# docker-compose.yml
version: '3.8'
services:
# Redis service for shared caching
redis:
image: redis:7-alpine
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
ports:
- "6379:6379"
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 10s
timeout: 3s
retries: 5
networks:
- traefik-network
# Traefik with TraefikOIDC plugin
traefik:
image: traefik:v3.2
command:
- "--api.dashboard=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--providers.file.filename=/etc/traefik/dynamic.yml"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.8.0"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
- ./letsencrypt:/letsencrypt
depends_on:
- redis
networks:
- traefik-network
# Your application
my-app:
image: my-app:latest
labels:
- "traefik.enable=true"
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
- "traefik.http.routers.my-app.entrypoints=websecure"
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
# OIDC Middleware Configuration with Redis (using labels)
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
# Redis configuration
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
networks:
- traefik-network
deploy:
replicas: 3 # Multiple replicas sharing Redis cache
volumes:
redis-data:
networks:
traefik-network:
driver: bridge
# ============================================================================
# Part 4: Kubernetes Example
# ============================================================================
---
# kubernetes-example.yaml
# Redis Deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: redis
namespace: traefik
spec:
replicas: 1
selector:
matchLabels:
app: redis
template:
metadata:
labels:
app: redis
spec:
containers:
- name: redis
image: redis:7-alpine
args:
- redis-server
- --requirepass
- $(REDIS_PASSWORD)
- --maxmemory
- 512mb
- --maxmemory-policy
- allkeys-lru
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secret
key: password
ports:
- containerPort: 6379
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
---
# Redis Service
apiVersion: v1
kind: Service
metadata:
name: redis
namespace: traefik
spec:
selector:
app: redis
ports:
- port: 6379
targetPort: 6379
---
# Redis Secret
apiVersion: v1
kind: Secret
metadata:
name: redis-secret
namespace: traefik
type: Opaque
stringData:
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
---
# OIDC Middleware with Redis
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
namespace: traefik
spec:
plugin:
traefikoidc:
# OIDC Configuration
clientID: "kubernetes-client-id"
clientSecret: "kubernetes-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
# Redis Configuration
redis:
enabled: true
address: "redis.traefik.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:k8s:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
enableHealthCheck: true
---
# IngressRoute using the middleware
apiVersion: traefik.io/v1alpha1
kind: IngressRoute
metadata:
name: my-app
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`app.example.com`)
kind: Rule
middlewares:
- name: oidc-auth
namespace: traefik
services:
- name: my-app
port: 80
tls:
certResolver: letsencrypt
# ============================================================================
# Part 5: Environment Variables (Optional Fallback)
# ============================================================================
# If you prefer environment variables as fallback (not recommended for production),
# you can set these. NOTE: Plugin configuration takes precedence!
# Docker Compose env file (.env)
---
# OIDC Configuration
OIDC_CLIENT_ID=your-client-id
OIDC_CLIENT_SECRET=your-client-secret
OIDC_PROVIDER_URL=https://auth.example.com
# Redis Configuration (fallback)
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=yourredispassword
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=20
REDIS_ENABLE_CIRCUIT_BREAKER=true
REDIS_ENABLE_HEALTH_CHECK=true
# ============================================================================
# Configuration Cheat Sheet
# ============================================================================
# Minimal Setup (Quick Start):
# redis:
# enabled: true
# address: "redis:6379"
# Production Setup (Recommended):
# redis:
# enabled: true
# address: "redis-master:6379"
# password: "strong-password"
# cacheMode: "hybrid"
# enableCircuitBreaker: true
# enableHealthCheck: true
# High Security Setup:
# redis:
# enabled: true
# address: "redis.example.com:6380"
# password: "strong-password"
# enableTLS: true
# tlsSkipVerify: false
# cacheMode: "redis"
# Cache Modes:
# - "memory": Local cache only (default, no Redis needed)
# - "redis": Redis only (consistent, shared across replicas)
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
+149
View File
@@ -0,0 +1,149 @@
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
# This example shows how to configure Redis through Traefik's dynamic configuration
# Static configuration (traefik.yml)
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
# Dynamic configuration (dynamic.yml or labels)
http:
middlewares:
# Example 1: Basic Redis configuration
oidc-redis-basic:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis configuration
redis:
enabled: true
address: "redis:6379"
# password: "your-redis-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
# Example 2: Redis with resilience features
oidc-redis-resilient:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with full resilience configuration
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
db: 1
keyPrefix: "myapp:"
poolSize: 20
connectTimeout: 10
readTimeout: 5
writeTimeout: 5
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
# Circuit breaker settings
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
# Health check settings
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS
oidc-redis-tls:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with TLS configuration
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
enableTLS: true
tlsSkipVerify: false # Set to true only for testing
cacheMode: "redis"
routers:
my-app:
rule: "Host(`app.example.com`)"
middlewares:
- oidc-redis-basic
service: my-app-service
services:
my-app-service:
loadBalancer:
servers:
- url: "http://localhost:8080"
# Docker Compose labels example
# version: '3.8'
# services:
# traefik:
# image: traefik:v3.0
# # ... other config ...
#
# my-app:
# image: my-app:latest
# labels:
# - "traefik.enable=true"
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
# - "traefik.http.routers.my-app.middlewares=my-oidc"
# # OIDC middleware configuration with Redis
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# # Redis configuration via labels
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
#
# redis:
# image: redis:7-alpine
# command: redis-server --requirepass redis-password
# # ... other config ...
# Environment variable fallback (optional)
# If Redis configuration is not provided in Traefik config, these environment variables
# can be used as a fallback (but Traefik config takes precedence):
#
# REDIS_ENABLED=true
# REDIS_ADDRESS=redis:6379
# REDIS_PASSWORD=secret
# REDIS_DB=0
# REDIS_KEY_PREFIX=traefikoidc:
# REDIS_CACHE_MODE=redis
# REDIS_POOL_SIZE=10
# REDIS_CONNECT_TIMEOUT=5
# REDIS_READ_TIMEOUT=3
# REDIS_WRITE_TIMEOUT=3
# REDIS_ENABLE_TLS=false
# REDIS_TLS_SKIP_VERIFY=false
# REDIS_ENABLE_CIRCUIT_BREAKER=true
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
# REDIS_ENABLE_HEALTH_CHECK=true
# REDIS_HEALTH_CHECK_INTERVAL=30
-797
View File
@@ -1,797 +0,0 @@
package features
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock types for testing
type TemplatedHeader struct {
Name string `json:"name"`
Value string `json:"value"`
}
type MockConfig struct {
ProviderURL string `json:"providerURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackURL"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
Headers []TemplatedHeader `json:"headers"`
}
// TestTemplateHeaderFeatures consolidates all template header-related tests
func TestTemplateHeaderFeatures(t *testing.T) {
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
t.Run("Template_Execution_Context", testTemplateExecutionContext)
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
t.Run("Missing_Field_Handling", testMissingFieldHandling)
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
}
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
// receive wrong data types during execution - reproduces GitHub issue #55
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
testCases := []struct {
name string
templateText string
templateData interface{}
errorContains string
expectError bool
}{
{
name: "correct map data",
templateText: "Bearer {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "valid-token",
},
expectError: false,
},
{
name: "boolean as root context - reproduces issue #55",
templateText: "Bearer {{.AccessToken}}",
templateData: true,
expectError: true,
errorContains: "can't evaluate field AccessToken in type bool",
},
{
name: "string as root context",
templateText: "Bearer {{.AccessToken}}",
templateData: "just a string",
expectError: true,
errorContains: "can't evaluate field AccessToken in type string",
},
{
name: "nested claims access with correct data",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
},
},
expectError: false,
},
{
name: "nested claims with wrong structure",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": "not a map",
},
expectError: true,
errorContains: "can't evaluate field email in type",
},
{
name: "complex nested structure",
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "token123",
"Claims": map[string]interface{}{
"sub": "user-id",
"groups": "admin,users",
},
},
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.templateData)
if tc.expectError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
} else {
require.NoError(t, err)
}
})
}
}
// testTemplateParsingValidation ensures templates are parsed correctly
func testTemplateParsingValidation(t *testing.T) {
testCases := []struct {
name string
headerTemplates []TemplatedHeader
shouldError bool
}{
{
name: "valid bearer token template",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
shouldError: false,
},
{
name: "multiple valid templates",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
shouldError: false,
},
{
name: "template with conditional logic",
headerTemplates: []TemplatedHeader{
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
},
shouldError: false,
},
{
name: "invalid template syntax",
headerTemplates: []TemplatedHeader{
{Name: "Bad-Template", Value: "{{.AccessToken"},
},
shouldError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, header := range tc.headerTemplates {
_, err := template.New(header.Name).Parse(header.Value)
if tc.shouldError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
})
}
}
// testMiddlewareHeaderTemplating simulates the actual middleware flow
func testMiddlewareHeaderTemplating(t *testing.T) {
testCases := []struct {
name string
headers []TemplatedHeader
accessToken string
idToken string
claims map[string]interface{}
expectedValues map[string]string
}{
{
name: "authorization header with access token",
headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
expectedValues: map[string]string{
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
},
},
{
name: "multiple headers with claims",
headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
},
accessToken: "token123",
claims: map[string]interface{}{
"email": "user@example.com",
"groups": "admin,developers",
},
expectedValues: map[string]string{
"X-User-Email": "user@example.com",
"X-User-Groups": "admin,developers",
"X-Auth-Token": "token123",
},
},
{
name: "complex template expressions",
headers: []TemplatedHeader{
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
},
accessToken: "access-token",
idToken: "id-token",
claims: map[string]interface{}{
"sub": "user-12345",
"email": "john@example.com",
},
expectedValues: map[string]string{
"X-User-Info": "user-12345 (john@example.com)",
"X-Auth-Header": "Bearer access-token | ID: id-token",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Parse all templates
headerTemplates := make(map[string]*template.Template)
for _, header := range tc.headers {
tmpl, err := template.New(header.Name).Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = tmpl
}
// Create template data
templateData := map[string]interface{}{
"AccessToken": tc.accessToken,
"IDToken": tc.idToken,
"Claims": tc.claims,
}
// Create a test request
req := httptest.NewRequest("GET", "/test", nil)
// Execute templates and set headers
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
req.Header.Set(headerName, buf.String())
}
// Verify all expected headers are set correctly
for headerName, expectedValue := range tc.expectedValues {
actualValue := req.Header.Get(headerName)
assert.Equal(t, expectedValue, actualValue)
}
})
}
}
// testJSONConfigParsing tests that JSON configuration is properly parsed
func testJSONConfigParsing(t *testing.T) {
testCases := []struct {
name string
jsonConfig string
expectedError bool
description string
}{
{
name: "valid JSON configuration",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": "Bearer {{.AccessToken}}"
}
]
}`,
expectedError: false,
description: "Properly formatted JSON with string values",
},
{
name: "JSON with boolean value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": true
}
]
}`,
expectedError: true,
description: "Boolean value instead of string template",
},
{
name: "JSON with number value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": 123
}
]
}`,
expectedError: true,
description: "Number value instead of string template",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var config struct {
Headers []TemplatedHeader `json:"headers"`
}
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
if tc.expectedError {
require.Error(t, err, tc.description)
} else {
require.NoError(t, err, tc.description)
}
})
}
}
// testTemplateDoubleProcessing tests if template strings are being double-processed
func testTemplateDoubleProcessing(t *testing.T) {
// Simulate how Traefik passes config to the plugin
config := &MockConfig{
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Verify that template strings are still raw (not processed)
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
// Simulate template parsing during initialization
headerTemplates := make(map[string]*template.Template)
funcMap := template.FuncMap{
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" || val == "<no value>" {
return defaultVal
}
return val
},
"get": func(m interface{}, key string) interface{} {
if mapVal, ok := m.(map[string]interface{}); ok {
if val, exists := mapVal[key]; exists {
return val
}
}
return ""
},
}
for _, header := range config.Headers {
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
parsedTmpl, err := tmpl.Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = parsedTmpl
}
// Test execution with actual claims
claims := map[string]interface{}{
"email": "user@example.com",
// Note: internal_role is missing
}
templateData := map[string]interface{}{
"Claims": claims,
}
// Execute templates
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
result := buf.String()
if headerName == "X-User-Email" {
assert.Equal(t, "user@example.com", result)
} else if headerName == "X-User-Role" {
// With missingkey=zero, missing fields return "<no value>"
assert.Equal(t, "<no value>", result)
}
}
}
// testTemplateExecutionContext tests the specific template data context
func testTemplateExecutionContext(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expectedValue string
}{
{
name: "Access and ID token distinction",
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IDToken": "id-token-value",
"Claims": map[string]interface{}{},
},
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "access-token",
"IDToken": "id-token",
"Claims": map[string]interface{}{
"sub": "user123",
},
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Custom non-standard claims",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"Claims": map[string]interface{}{
"role": "admin",
"permissions": "read:all,write:own",
},
},
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
},
{
name: "Deeply nested custom claims",
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"app_metadata": map[string]interface{}{
"organization": map[string]interface{}{
"name": "acme-corp",
},
"team": "platform",
},
},
},
expectedValue: "X-Organization: acme-corp, X-Team: platform",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expectedValue, buf.String())
})
}
}
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
func testTemplateIntegrationWithPlugin(t *testing.T) {
// Test template integration using mock plugin components
// Set up test OIDC server
var testServerURL string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": testServerURL,
"authorization_endpoint": testServerURL + "/auth",
"token_endpoint": testServerURL + "/token",
"jwks_uri": testServerURL + "/jwks",
"userinfo_endpoint": testServerURL + "/userinfo",
})
case "/jwks":
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
http.NotFound(w, r)
}
}))
defer testServer.Close()
testServerURL = testServer.URL
// Create config with templates that reference potentially missing fields
config := &MockConfig{
ProviderURL: testServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-characters",
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Initialize plugin would be done here
ctx := context.Background()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test would create plugin handler here
_ = ctx
_ = next
_ = config
}
// testTemplateSyntaxValidation tests that template syntax is properly validated
func testTemplateSyntaxValidation(t *testing.T) {
validTemplates := []string{
"{{.Claims.email}}",
"{{.Claims.internal_role}}",
"{{.AccessToken}}",
"{{.IdToken}}",
"{{.RefreshToken}}",
}
for _, tmplStr := range validTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
}
// Test invalid templates
invalidTemplates := []struct {
template string
reason string
}{
{"{{call .SomeFunc}}", "function calls not allowed"},
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
{"{{index .Array 0}}", "index access blocked"},
{"{{printf \"%s\" .Data}}", "printf blocked"},
}
for _, tc := range invalidTemplates {
err := validateTemplateSecure(tc.template)
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
}
// Test safe custom functions
safeTemplates := []string{
"{{get .Claims \"internal_role\"}}",
"{{default \"guest\" .Claims.role}}",
}
for _, tmplStr := range safeTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
}
}
// Mock validation function for template security
func validateTemplateSecure(templateStr string) error {
// List of potentially dangerous template actions
dangerousFunctions := []string{
"call", "range", "with", "index", "printf", "println", "print",
"js", "html", "urlquery", "base64", "exec",
}
for _, dangerous := range dangerousFunctions {
if strings.Contains(templateStr, dangerous) {
return fmt.Errorf("dangerous template function detected: %s", dangerous)
}
}
// Define safe custom functions
funcMap := template.FuncMap{
"get": func(data map[string]interface{}, key string) interface{} {
return data[key]
},
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" {
return defaultVal
}
return val
},
}
// Try to parse the template with custom functions to check for syntax errors
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
return err
}
// testMissingFieldHandling tests handling of missing fields in templates
func testMissingFieldHandling(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "missing claim field",
templateText: "{{.Claims.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{},
},
expected: "<no value>",
},
{
name: "missing nested field",
templateText: "{{.Claims.user.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"user": map[string]interface{}{},
},
},
expected: "<no value>",
},
{
name: "missing entire path",
templateText: "{{.Missing.Path.Field}}",
data: map[string]interface{}{},
expected: "<no value>",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testComplexTemplateExpressions tests complex template expressions
func testComplexTemplateExpressions(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "conditional template",
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"admin": true,
},
},
expected: "Admin User",
},
{
name: "multiple claims concatenation",
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"firstName": "John",
"lastName": "Doe",
"email": "john.doe@example.com",
},
},
expected: "John Doe <john.doe@example.com>",
},
{
name: "array access",
templateText: "{{index .Claims.roles 0}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"roles": []string{"admin", "user"},
},
},
expected: "admin",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
func testTraefikConfigurationParsing(t *testing.T) {
testCases := []struct {
name string
config *MockConfig
expectError bool
description string
}{
{
name: "valid configuration with templated headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
},
expectError: false,
description: "Standard configuration should work",
},
{
name: "configuration with multiple headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
},
expectError: false,
description: "Multiple headers should work",
},
{
name: "empty headers configuration",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{},
},
expectError: false,
description: "Empty headers should not cause issues",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a simple next handler
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Try to create the middleware would be done here
ctx := context.Background()
// Test would create middleware handler here
_ = ctx
_ = next
_ = tc.config
// For now, we just validate the configuration is well-formed
if !tc.expectError {
require.NotNil(t, tc.config, tc.description)
require.NotEmpty(t, tc.config.ClientID, tc.description)
}
})
}
}
+8 -2
View File
@@ -3,15 +3,21 @@ module github.com/lukaszraczylo/traefikoidc
go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.13.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
)
+18 -2
View File
@@ -1,5 +1,15 @@
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -10,10 +20,16 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+7 -7
View File
@@ -10,16 +10,16 @@ import (
type GoroutineManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
goroutines map[string]*managedGoroutine
logger *Logger
wg sync.WaitGroup
mu sync.RWMutex
}
type managedGoroutine struct {
name string
cancel context.CancelFunc
startTime time.Time
cancel context.CancelFunc
name string
running bool
}
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
for {
select {
case <-ctx.Done():
m.logger.Debugf("Periodic task %s cancelled", name)
m.logger.Debugf("Periodic task %s canceled", name)
return
case <-ticker.C:
task()
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
// GoroutineStatus represents the status of a managed goroutine
type GoroutineStatus struct {
Name string
Running bool
StartTime time.Time
Name string
Runtime time.Duration
Running bool
}
// ErrShutdownTimeout is returned when shutdown times out
+625
View File
@@ -0,0 +1,625 @@
package traefikoidc
import (
"context"
"sync/atomic"
"testing"
"time"
)
// Test GoroutineManager Creation
func TestNewGoroutineManager(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
if gm == nil {
t.Fatal("Expected non-nil goroutine manager")
}
if gm.ctx == nil {
t.Error("Expected context to be initialized")
}
if gm.cancel == nil {
t.Error("Expected cancel function to be initialized")
}
if gm.goroutines == nil {
t.Error("Expected goroutines map to be initialized")
}
if gm.logger != logger {
t.Error("Expected logger to be set")
}
}
// Test Starting Goroutines
func TestStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
executed.Store(true)
})
// Give goroutine time to execute
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if len(status) != 1 {
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
}
if _, exists := status["test-goroutine"]; !exists {
t.Error("Expected goroutine 'test-goroutine' in status")
}
}
func TestStartGoroutineDuplicate(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
// Start a long-running goroutine
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
})
// Give first goroutine time to start
time.Sleep(50 * time.Millisecond)
// Try to start another with same name (should be skipped)
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
// Should only have executed once
if counter.Load() != 1 {
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
}
}
func TestStartGoroutineContextCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
started := atomic.Bool{}
canceled := atomic.Bool{}
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
started.Store(true)
<-ctx.Done()
canceled.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
if !started.Load() {
t.Error("Expected goroutine to start")
}
// Stop the goroutine
gm.StopGoroutine("cancel-test")
// Wait for cancellation
time.Sleep(50 * time.Millisecond)
if !canceled.Load() {
t.Error("Expected goroutine to be canceled")
}
}
func TestStartGoroutineWithPanic(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("panic-test", func(ctx context.Context) {
executed.Store(true)
panic("test panic")
})
// Give goroutine time to panic and recover
time.Sleep(100 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute before panic")
}
// Check that goroutine is marked as not running after panic
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running after panic")
}
}
// Manager should still be functional
counter := atomic.Int32{}
gm.StartGoroutine("after-panic", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
if counter.Load() != 1 {
t.Error("Expected manager to still be functional after panic recovery")
}
}
// Test Periodic Tasks
func TestStartPeriodicTask(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for multiple executions
time.Sleep(160 * time.Millisecond)
// Should have executed at least 2-3 times
count := counter.Load()
if count < 2 {
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
}
}
func TestStartPeriodicTaskCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for some executions
time.Sleep(120 * time.Millisecond)
// Stop the task
gm.StopGoroutine("cancel-periodic")
countBeforeStop := counter.Load()
// Wait and verify no more executions
time.Sleep(120 * time.Millisecond)
countAfterStop := counter.Load()
// Allow 1 additional execution (could be in progress when stopped)
if countAfterStop > countBeforeStop+1 {
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
countBeforeStop, countAfterStop)
}
}
// Test Stopping Goroutines
func TestStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
stopped := atomic.Bool{}
gm.StartGoroutine("stop-test", func(ctx context.Context) {
<-ctx.Done()
stopped.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
gm.StopGoroutine("stop-test")
// Wait for goroutine to stop
time.Sleep(50 * time.Millisecond)
if !stopped.Load() {
t.Error("Expected goroutine to be stopped")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["stop-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running")
}
}
}
func TestStopGoroutineNonExistent(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Should not panic or error when stopping non-existent goroutine
gm.StopGoroutine("non-existent")
}
func TestStopGoroutineAlreadyStopped(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
// Exit immediately
})
// Wait for goroutine to finish
time.Sleep(50 * time.Millisecond)
// Try to stop already-stopped goroutine (should be safe)
gm.StopGoroutine("already-stopped")
}
// Test Shutdown
func TestShutdownGraceful(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
counter := atomic.Int32{}
// Start multiple goroutines
for i := 0; i < 5; i++ {
name := "goroutine-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
counter.Add(-1)
})
}
// Wait for all to start
time.Sleep(100 * time.Millisecond)
if counter.Load() != 5 {
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
}
// Shutdown with generous timeout
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected graceful shutdown, got error: %v", err)
}
if counter.Load() != 0 {
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
}
}
func TestShutdownWithTimeout(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
gm.StartGoroutine("stubborn", func(ctx context.Context) {
// Simulate a goroutine that takes too long to stop
time.Sleep(500 * time.Millisecond)
})
time.Sleep(50 * time.Millisecond)
// Shutdown with very short timeout
err := gm.Shutdown(10 * time.Millisecond)
if err == nil {
t.Error("Expected timeout error")
}
if err != ErrShutdownTimeout {
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
}
}
func TestShutdownEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown with no goroutines should succeed immediately
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected no error for empty shutdown, got: %v", err)
}
}
// Test Status
func TestGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start multiple goroutines with different states
gm.StartGoroutine("running", func(ctx context.Context) {
<-ctx.Done()
})
gm.StartGoroutine("quick", func(ctx context.Context) {
// Exits immediately
})
time.Sleep(50 * time.Millisecond)
status := gm.GetStatus()
if len(status) != 2 {
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
}
if runningStatus, exists := status["running"]; exists {
if !runningStatus.Running {
t.Error("Expected 'running' goroutine to be marked as running")
}
if runningStatus.Name != "running" {
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
}
if runningStatus.StartTime.IsZero() {
t.Error("Expected non-zero start time")
}
if runningStatus.Runtime <= 0 {
t.Error("Expected positive runtime")
}
} else {
t.Error("Expected 'running' goroutine in status")
}
if quickStatus, exists := status["quick"]; exists {
if quickStatus.Running {
t.Error("Expected 'quick' goroutine to be marked as not running")
}
} else {
t.Error("Expected 'quick' goroutine in status")
}
}
func TestGetStatusEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
status := gm.GetStatus()
if status == nil {
t.Fatal("Expected non-nil status map")
}
if len(status) != 0 {
t.Errorf("Expected empty status, got %d entries", len(status))
}
}
// Test Concurrent Operations
func TestConcurrentStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(2 * time.Second)
counter := atomic.Int32{}
const numGoroutines = 50
// Start many goroutines concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
time.Sleep(50 * time.Millisecond)
counter.Add(-1)
})
}(i)
}
// Wait for all to start
time.Sleep(150 * time.Millisecond)
// Verify goroutines are tracked
status := gm.GetStatus()
if len(status) < numGoroutines/2 {
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
}
}
func TestConcurrentStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
const numGoroutines = 20
// Start goroutines
for i := 0; i < numGoroutines; i++ {
name := "stop-concurrent-" + string(rune('0'+i%10))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
time.Sleep(50 * time.Millisecond)
// Stop all concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "stop-concurrent-" + string(rune('0'+id%10))
gm.StopGoroutine(name)
}(i)
}
time.Sleep(100 * time.Millisecond)
// Verify all stopped
status := gm.GetStatus()
for _, s := range status {
if s.Running {
t.Errorf("Expected goroutine %s to be stopped", s.Name)
}
}
}
func TestConcurrentGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start some goroutines
for i := 0; i < 10; i++ {
name := "status-test-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
// Concurrently read status many times (should not race)
done := make(chan struct{})
for i := 0; i < 20; i++ {
go func() {
for j := 0; j < 100; j++ {
_ = gm.GetStatus()
}
done <- struct{}{}
}()
}
// Wait for all concurrent reads
for i := 0; i < 20; i++ {
<-done
}
}
// Test Error Cases
func TestShutdownTimeoutError(t *testing.T) {
err := ErrShutdownTimeout
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
t.Errorf("Unexpected error message: %s", err.Error())
}
}
// Test Edge Cases
func TestStartGoroutineAfterShutdown(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown immediately
_ = gm.Shutdown(time.Second)
executed := atomic.Bool{}
// Try to start goroutine after shutdown
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
executed.Store(true)
<-ctx.Done()
})
time.Sleep(50 * time.Millisecond)
// Goroutine should have started but context already canceled
// It may or may not execute depending on timing, but shouldn't panic
status := gm.GetStatus()
if _, exists := status["after-shutdown"]; exists {
// If it's in status, it was tracked (acceptable)
t.Log("Goroutine was tracked even after shutdown")
}
}
func TestMultipleShutdowns(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// First shutdown
err1 := gm.Shutdown(time.Second)
if err1 != nil {
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
}
// Second shutdown (should not panic or error)
err2 := gm.Shutdown(time.Second)
if err2 != nil {
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
}
}
func TestGoroutineWithImmediateReturn(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("immediate", func(ctx context.Context) {
executed.Store(true)
// Return immediately
})
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["immediate"]; exists {
if goroutineStatus.Running {
t.Error("Expected immediately-returning goroutine to be marked as not running")
}
}
}
func TestPeriodicTaskPanicRecovery(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
counter.Add(1)
if counter.Load() == 2 {
panic("periodic panic")
}
})
// Wait for panic to occur
time.Sleep(200 * time.Millisecond)
// After panic, the goroutine should have stopped
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-periodic"]; exists {
if goroutineStatus.Running {
t.Error("Expected panicked periodic task to stop")
}
}
}
-764
View File
@@ -1,764 +0,0 @@
package handlers
import (
"errors"
"net/http"
"sync"
"testing"
"time"
)
// ============================================================================
// OAuth Handler Tests
// ============================================================================
func TestOAuthHandler(t *testing.T) {
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
// Test authorization request handling logic
logger := &MockLogger{}
tests := []struct {
name string
requestURL string
expectedStatus int
checkLocation bool
}{
{
name: "Valid authorization request",
requestURL: "/auth/login",
expectedStatus: http.StatusFound,
checkLocation: true,
},
{
name: "With return URL",
requestURL: "/auth/login?return=/dashboard",
expectedStatus: http.StatusFound,
checkLocation: true,
},
}
// Test the test case structure
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.requestURL == "" {
t.Error("Request URL should not be empty")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// In a real implementation, this would test the actual handler
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Authorization request test completed")
})
t.Run("HandleCallbackRequest", func(t *testing.T) {
// Test callback request handling with existing mocks
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
tests := []struct {
name string
queryParams string
expectedStatus int
expectError bool
}{
{
name: "Valid callback with code",
queryParams: "code=test-code&state=test-state",
expectedStatus: http.StatusFound,
expectError: false,
},
{
name: "Callback with error",
queryParams: "error=access_denied&error_description=User denied access",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing code",
queryParams: "state=test-state",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing state",
queryParams: "code=test-code",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
}
// Test the callback scenarios
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.queryParams == "" && !test.expectError {
t.Error("Query params should not be empty for successful cases")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// Test session manager functionality
if sessionManager != nil {
t.Logf("Session manager available for test %s", test.name)
}
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Callback request test completed")
})
t.Run("HandleLogout", func(t *testing.T) {
// Test logout functionality with mock implementations
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
// Test session clearing
mockReq := &http.Request{}
session, err := sessionManager.GetSession(mockReq)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up authenticated session
err = session.SetAuthenticated(true)
if err != nil {
t.Fatalf("Failed to set authentication: %v", err)
}
session.SetIDToken("test-token")
// Verify session is authenticated
if !session.GetAuthenticated() {
t.Error("Session should be authenticated before logout")
}
// Test logout by clearing session
// session.Clear() // Method not implemented in SessionData
// Additional logout verification would go here
// Verify logger doesn't cause issues
logger.Debugf("Logout test completed")
t.Log("Logout test completed successfully")
})
}
// ============================================================================
// Auth Handler Tests
// ============================================================================
func TestAuthHandler(t *testing.T) {
t.Run("HandleAuthentication", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
/*
handler := &MockAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
tests := []struct {
name string
setupSession func(*MockSession)
expectedStatus int
expectNext bool
}{
{
name: "Authenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("valid-token")
},
expectedStatus: http.StatusOK,
expectNext: true,
},
{
name: "Unauthenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(false)
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
{
name: "Expired token",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("expired-token")
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("HandleRefreshToken", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
tests := []struct {
name string
refreshToken string
mockResponse *MockTokenResponse
mockError error
expectSuccess bool
}{
{
name: "Successful refresh",
refreshToken: "valid-refresh-token",
mockResponse: &MockTokenResponse{
AccessToken: "new-access-token",
IDToken: "new-id-token",
RefreshToken: "new-refresh-token",
},
expectSuccess: true,
},
{
name: "Failed refresh",
refreshToken: "invalid-refresh-token",
mockError: errors.New("invalid_grant"),
expectSuccess: false,
},
{
name: "Empty refresh token",
refreshToken: "",
expectSuccess: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Error Handler Tests
// ============================================================================
func TestErrorHandler(t *testing.T) {
t.Run("HandleHTTPErrors", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
errorCode int
errorMessage string
isAjax bool
expectedStatus int
expectedBody string
}{
{
name: "401 Unauthorized",
errorCode: http.StatusUnauthorized,
errorMessage: "Authentication required",
isAjax: false,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Authentication required",
},
{
name: "403 Forbidden",
errorCode: http.StatusForbidden,
errorMessage: "Access denied",
isAjax: false,
expectedStatus: http.StatusForbidden,
expectedBody: "Access denied",
},
{
name: "500 Internal Server Error",
errorCode: http.StatusInternalServerError,
errorMessage: "Internal server error",
isAjax: false,
expectedStatus: http.StatusInternalServerError,
expectedBody: "Internal server error",
},
{
name: "Ajax 401",
errorCode: http.StatusUnauthorized,
errorMessage: "Token expired",
isAjax: true,
expectedStatus: http.StatusUnauthorized,
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("RecoverFromPanic", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
panicValue interface{}
expectError bool
}{
{
name: "String panic",
panicValue: "something went wrong",
expectError: true,
},
{
name: "Error panic",
panicValue: errors.New("critical error"),
expectError: true,
},
{
name: "Nil panic",
panicValue: nil,
expectError: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Azure OAuth Callback Tests
// ============================================================================
func TestAzureOAuthCallback(t *testing.T) {
t.Run("AzureSpecificClaims", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
azureClaims := map[string]interface{}{
"oid": "object-id",
"tid": "tenant-id",
"preferred_username": "user@example.com",
"name": "Test User",
"email": "user@example.com",
"groups": []string{"group1", "group2"},
}
// Test would go here when properly implemented
_ = azureClaims
})
t.Run("AzureTokenValidation", func(t *testing.T) {
// Test with mock validator types
/*
validator := &MockAzureTokenValidator{
tenantID: "test-tenant",
clientID: "test-client",
}
*/
tests := []struct {
name string
token string
claims map[string]interface{}
expectValid bool
}{
{
name: "Valid Azure token",
token: "valid-azure-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: true,
},
{
name: "Wrong tenant",
token: "wrong-tenant-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "wrong-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
{
name: "Wrong audience",
token: "wrong-audience-token",
claims: map[string]interface{}{
"aud": "wrong-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Concurrent Handler Tests
// ============================================================================
func TestConcurrentHandlers(t *testing.T) {
t.Run("ConcurrentCallbacks", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
successCount := int32(0)
errorCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = successCount
_ = errorCount
})
t.Run("ConcurrentLogouts", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
logoutCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = logoutCount
})
}
// ============================================================================
// Mock Implementations
// ============================================================================
type MockSessionManager struct {
sessions map[string]*MockSession
mu sync.RWMutex
}
func NewMockSessionManager() *MockSessionManager {
return &MockSessionManager{
sessions: make(map[string]*MockSession),
}
}
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
m.mu.Lock()
defer m.mu.Unlock()
sessionID := "test-session"
if session, exists := m.sessions[sessionID]; exists {
return session, nil
}
session := &MockSession{
values: make(map[string]interface{}),
}
m.sessions[sessionID] = session
return session, nil
}
type MockSession struct {
values map[string]interface{}
mu sync.RWMutex
}
func (s *MockSession) SetAuthenticated(auth bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.values["authenticated"] = auth
return nil
}
func (s *MockSession) GetAuthenticated() bool {
s.mu.RLock()
defer s.mu.RUnlock()
auth, ok := s.values["authenticated"].(bool)
return ok && auth
}
func (s *MockSession) SetIDToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["id_token"] = token
}
func (s *MockSession) GetIDToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["id_token"].(string)
return token
}
func (s *MockSession) SetAccessToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["access_token"] = token
}
func (s *MockSession) GetAccessToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["access_token"].(string)
return token
}
func (s *MockSession) SetRefreshToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["refresh_token"] = token
}
func (s *MockSession) GetRefreshToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["refresh_token"].(string)
return token
}
func (s *MockSession) SetState(state string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["state"] = state
}
func (s *MockSession) GetState() string {
s.mu.RLock()
defer s.mu.RUnlock()
state, _ := s.values["state"].(string)
return state
}
func (s *MockSession) SetClaims(claims map[string]interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["claims"] = claims
}
func (s *MockSession) GetClaims() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
claims, _ := s.values["claims"].(map[string]interface{})
return claims
}
// Additional SessionData interface methods to match real interface
func (s *MockSession) GetCSRF() string {
s.mu.RLock()
defer s.mu.RUnlock()
csrf, _ := s.values["csrf"].(string)
return csrf
}
func (s *MockSession) GetNonce() string {
s.mu.RLock()
defer s.mu.RUnlock()
nonce, _ := s.values["nonce"].(string)
return nonce
}
func (s *MockSession) GetCodeVerifier() string {
s.mu.RLock()
defer s.mu.RUnlock()
verifier, _ := s.values["code_verifier"].(string)
return verifier
}
func (s *MockSession) GetIncomingPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
path, _ := s.values["incoming_path"].(string)
return path
}
func (s *MockSession) SetEmail(email string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["email"] = email
}
func (s *MockSession) GetEmail() string {
s.mu.RLock()
defer s.mu.RUnlock()
email, _ := s.values["email"].(string)
return email
}
func (s *MockSession) SetCSRF(csrf string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["csrf"] = csrf
}
func (s *MockSession) SetNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["nonce"] = nonce
}
func (s *MockSession) SetCodeVerifier(verifier string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["code_verifier"] = verifier
}
func (s *MockSession) SetIncomingPath(path string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["incoming_path"] = path
}
func (s *MockSession) ResetRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.values["redirect_count"] = 0
}
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
return nil
}
func (s *MockSession) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.values = make(map[string]interface{})
}
func (s *MockSession) returnToPoolSafely() {
// No-op for mock
}
type MockTokenValidator struct {
valid bool
}
func (v *MockTokenValidator) Validate(token string) bool {
if token == "expired-token" {
return false
}
return v.valid
}
// ============================================================================
// Mock Handler Type Definitions (for testing)
// ============================================================================
// These mock handlers are simplified versions for testing purposes
// They don't match the actual handler implementations
type MockAuthHandler struct{}
type MockErrorHandler struct{}
type MockAzureTokenValidator struct {
tenantID string
clientID string
}
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
// Validate tenant ID
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
return false
}
// Validate audience
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
return false
}
// Validate expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return false
}
}
return true
}
// ============================================================================
// Helper Types and Mock Logger
// ============================================================================
type MockLogger struct{}
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
func (l *MockLogger) Error(msg string) {}
type MockTokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
}
-308
View File
@@ -1,308 +0,0 @@
// Package handlers provides HTTP request handlers for the OIDC middleware.
package handlers
import (
"context"
"fmt"
"net/http"
"strings"
)
// OAuthHandler handles OAuth callback requests
type OAuthHandler struct {
logger Logger
sessionManager SessionManager
tokenExchanger TokenExchanger
tokenVerifier TokenVerifier
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
isAllowedDomainFunc func(email string) bool
redirURLPath string
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Error(msg string)
}
// SessionManager interface for session operations
type SessionManager interface {
GetSession(req *http.Request) (SessionData, error)
}
// SessionData interface for session data operations
type SessionData interface {
GetCSRF() string
GetNonce() string
GetCodeVerifier() string
GetIncomingPath() string
GetAuthenticated() bool
GetAccessToken() string
GetRefreshToken() string
GetIDToken() string
GetEmail() string
SetAuthenticated(bool) error
SetEmail(string)
SetIDToken(string)
SetAccessToken(string)
SetRefreshToken(string)
SetCSRF(string)
SetNonce(string)
SetCodeVerifier(string)
SetIncomingPath(string)
ResetRedirectCount()
Save(req *http.Request, rw http.ResponseWriter) error
returnToPoolSafely()
}
// TokenExchanger interface for token operations
type TokenExchanger interface {
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
}
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// TokenResponse represents the response from token exchange
type TokenResponse struct {
IDToken string
AccessToken string
RefreshToken string
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
isAllowedDomainFunc func(string) bool, redirURLPath string,
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
return &OAuthHandler{
logger: logger,
sessionManager: sessionManager,
tokenExchanger: tokenExchanger,
tokenVerifier: tokenVerifier,
extractClaimsFunc: extractClaimsFunc,
isAllowedDomainFunc: isAllowedDomainFunc,
redirURLPath: redirURLPath,
sendErrorResponseFunc: sendErrorResponseFunc,
}
}
// HandleCallback handles OAuth callback requests
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := h.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Session error during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Debug logging for cookie configuration
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
// Log all cookies in the request for debugging
cookies := req.Cookies()
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, "_oidc_") {
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
}
}
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
h.logger.Error("No state in callback")
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
// Debug log the state parameter received
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
csrfToken := session.GetCSRF()
if csrfToken == "" {
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
// Enhanced debugging for missing CSRF token
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
h.logger.Errorf("Main session cookie not found in request: %v", err)
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
} else {
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
}
// Log session state for debugging
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
session.GetAuthenticated(), session.GetAccessToken() != "")
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
// Debug log successful CSRF token retrieval
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
if state != csrfToken {
h.logger.Error("State parameter does not match CSRF token in session during callback")
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
h.logger.Error("No code in callback")
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
h.logger.Errorf("Failed to extract claims during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
h.logger.Error("Nonce claim missing in id_token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
h.logger.Error("Nonce not found in session during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
h.logger.Error("Nonce claim does not match session nonce during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
email, _ := claims["email"].(string)
if email == "" {
h.logger.Errorf("Email claim missing or empty in token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !h.isAllowedDomainFunc(email) {
h.logger.Errorf("Disallowed email domain during callback: %s", email)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session after callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// URLHelper provides utility methods for URL operations
type URLHelper struct {
logger Logger
}
// NewURLHelper creates a new URL helper
func NewURLHelper(logger Logger) *URLHelper {
return &URLHelper{logger: logger}
}
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
// It compares the request path against configured excluded URL prefixes.
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
for excludedURL := range excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
return false
}
// DetermineScheme determines the URL scheme for building redirect URLs.
// It checks X-Forwarded-Proto header first, then TLS presence.
func (h *URLHelper) DetermineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// DetermineHost determines the host for building redirect URLs.
// It checks X-Forwarded-Host header first, then falls back to req.Host.
func (h *URLHelper) DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
+22 -10
View File
@@ -13,6 +13,8 @@ import (
"net/url"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
@@ -109,7 +111,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
client := t.tokenHTTPClient
if client == nil {
// Use shared transport pool to prevent memory leaks
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
pooledClient := CreateTokenHTTPClient()
client = &http.Client{
Transport: pooledClient.Transport,
@@ -124,7 +126,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
// Read tokenURL with RLock
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -135,13 +142,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
_ = resp.Body.Close() // Safe to ignore: closing body on defer
}()
if resp.StatusCode != http.StatusOK {
limitReader := io.LimitReader(resp.Body, 1024*10)
bodyBytes, _ := io.ReadAll(limitReader)
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
@@ -232,7 +239,7 @@ func NewTokenCache() *TokenCache {
// - expiration: The duration for which the cache entry should be valid
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
}
// Get retrieves cached claims for a token.
@@ -344,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
host := utils.DetermineHost(req)
scheme := utils.DetermineScheme(req, t.forceHTTPS)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
@@ -355,8 +362,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
// Read endSessionURL with RLock
t.metadataMu.RLock()
endSessionURL := t.endSessionURL
t.metadataMu.RUnlock()
if endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
+29 -24
View File
@@ -12,30 +12,23 @@ import (
// HTTPClientConfig provides configuration for creating HTTP clients
type HTTPClientConfig struct {
// Timeout for the entire request
Timeout time.Duration
// MaxRedirects allowed (0 means follow Go's default of 10)
MaxRedirects int
// UseCookieJar enables cookie jar for the client
UseCookieJar bool
// Connection settings
IdleConnTimeout time.Duration
MaxIdleConns int
ReadBufferSize int
DialTimeout time.Duration
KeepAlive time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
// Connection pool settings
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Buffer settings
WriteBufferSize int
ReadBufferSize int
// Feature flags
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
MaxRedirects int
MaxIdleConnsPerHost int
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
}
// DefaultHTTPClientConfig returns the default configuration for general use
@@ -49,10 +42,10 @@ func DefaultHTTPClientConfig() HTTPClientConfig {
TLSHandshakeTimeout: 2 * time.Second,
ResponseHeaderTimeout: 3 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 5 * time.Second,
MaxIdleConns: 20, // SECURITY FIX: Reduced from 100 to limit resource usage
MaxIdleConnsPerHost: 2, // SECURITY FIX: Reduced from 10 to prevent connection exhaustion
MaxConnsPerHost: 5, // SECURITY FIX: Reduced from 10 to limit concurrent connections
IdleConnTimeout: 30 * time.Second, // OPTIMIZATION: Increased for better connection reuse
MaxIdleConns: 50, // OPTIMIZATION: Increased from 20 for better connection pooling
MaxIdleConnsPerHost: 10, // OPTIMIZATION: Increased from 2 for better connection reuse
MaxConnsPerHost: 20, // OPTIMIZATION: Increased from 5 while maintaining security
WriteBufferSize: 4096,
ReadBufferSize: 4096,
ForceHTTP2: true,
@@ -70,6 +63,18 @@ func TokenHTTPClientConfig() HTTPClientConfig {
return config
}
// OIDCProviderHTTPClientConfig returns configuration optimized for OIDC provider calls
func OIDCProviderHTTPClientConfig() HTTPClientConfig {
config := DefaultHTTPClientConfig()
config.Timeout = 15 * time.Second // Slightly longer for OIDC operations
config.MaxIdleConns = 100 // Higher pool for frequent OIDC calls
config.MaxIdleConnsPerHost = 25 // More connections per OIDC provider
config.MaxConnsPerHost = 50 // Allow more concurrent requests to OIDC provider
config.IdleConnTimeout = 90 * time.Second // Keep connections alive longer for reuse
config.UseCookieJar = true // Enable cookie jar for session management
return config
}
// HTTPClientFactory provides methods for creating configured HTTP clients
type HTTPClientFactory struct{}
@@ -233,7 +238,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
// Add cookie jar if requested
if config.UseCookieJar {
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
client.Jar = jar
}
+210
View File
@@ -0,0 +1,210 @@
package traefikoidc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
config := OIDCProviderHTTPClientConfig()
// Verify OIDC-specific settings
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
}
// TestCreateDefaultClientUnit tests CreateDefaultClient function
func TestCreateDefaultClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateDefaultClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
}
// TestCreateTokenClientUnit tests CreateTokenClient function
func TestCreateTokenClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateTokenClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.NotNil(t, client.Jar, "token client should have cookie jar")
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
}
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
config := HTTPClientConfig{
Timeout: 5 * time.Second,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 5,
UseCookieJar: true,
}
client := CreateHTTPClientWithConfig(config)
require.NotNil(t, client)
assert.Equal(t, 5*time.Second, client.Timeout)
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
}
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
factory := NewHTTPClientFactory()
t.Run("zero values get defaults", func(t *testing.T) {
config := HTTPClientConfig{
// All zero values
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Verify defaults were applied
assert.Equal(t, 30*time.Second, client.Timeout)
})
t.Run("custom values preserved", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: 15 * time.Second,
MaxIdleConns: 50,
MaxRedirects: 3,
UseCookieJar: true,
ForceHTTP2: true,
DisableKeepAlives: true,
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
assert.Equal(t, 15*time.Second, client.Timeout)
assert.NotNil(t, client.Jar)
})
t.Run("invalid timeout gets default", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: -1 * time.Second, // Invalid
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Should get default due to validation failure
assert.Equal(t, 30*time.Second, client.Timeout)
})
}
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
factory := NewHTTPClientFactory()
tests := []struct {
name string
errorMsg string
config HTTPClientConfig
wantError bool
}{
{
name: "valid config",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 50,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 20,
},
wantError: false,
},
{
name: "negative MaxIdleConns",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: -1,
},
wantError: true,
errorMsg: "MaxIdleConns cannot be negative",
},
{
name: "MaxIdleConns too high",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 1500,
},
wantError: true,
errorMsg: "MaxIdleConns too high",
},
{
name: "negative MaxIdleConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: -1,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost cannot be negative",
},
{
name: "timeout too high",
config: HTTPClientConfig{
Timeout: 10 * time.Minute,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout too high",
},
{
name: "negative timeout",
config: HTTPClientConfig{
Timeout: -1 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout must be positive",
},
{
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: 50,
MaxConnsPerHost: 10,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := factory.ValidateHTTPClientConfig(&tt.config)
if tt.wantError {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+48 -16
View File
@@ -12,19 +12,19 @@ import (
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
type SharedTransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
transports map[string]*sharedTransport
cancel context.CancelFunc
clientCount int32 // SECURITY FIX: Track total HTTP clients
maxClients int32 // SECURITY FIX: Limit total clients to 5
maxConns int
mu sync.RWMutex
clientCount int32
maxClients int32
}
type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
lastUsed time.Time
}
var (
@@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
}
// cleanupIdleTransports periodically cleans up unused transports
// Uses two-phase cleanup to minimize lock contention:
// 1. Find candidates while holding read lock
// 2. Remove and close transports with minimal lock duration
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
p.mu.Lock()
now := time.Now()
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(p.transports, transportKey)
// SECURITY FIX: Decrement client count when removing transport
atomic.AddInt32(&p.clientCount, -1)
}
p.performCleanup()
}
}
}
// performCleanup does the actual cleanup with optimized locking
func (p *SharedTransportPool) performCleanup() {
now := time.Now()
// Phase 1: Find candidates while holding read lock (fast)
p.mu.RLock()
candidates := make([]string, 0)
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
candidates = append(candidates, transportKey)
}
}
p.mu.RUnlock()
if len(candidates) == 0 {
return
}
// Phase 2: Remove and close each candidate individually
// This minimizes lock contention and allows concurrent access
for _, key := range candidates {
p.mu.Lock()
shared, exists := p.transports[key]
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
// Remove from map first (releases memory)
delete(p.transports, key)
atomic.AddInt32(&p.clientCount, -1)
p.mu.Unlock()
// Close idle connections outside the lock (can be slow)
if shared.transport != nil {
shared.transport.CloseIdleConnections()
}
} else {
p.mu.Unlock()
}
}
+691
View File
@@ -0,0 +1,691 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
t.Run("create new transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
assert.Len(t, pool.transports, 1)
})
t.Run("reuse existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config)
transport2 := pool.GetOrCreateTransport(config)
assert.Equal(t, transport1, transport2, "should reuse same transport")
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
// Check ref count
pool.mu.RLock()
key := pool.configKey(config)
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
})
t.Run("client limit enforcement", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 5, // Already at max
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
assert.Nil(t, transport, "should return nil when at client limit")
})
t.Run("client limit with existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create first transport
config1 := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config1)
require.NotNil(t, transport1)
// Set client count to max
atomic.StoreInt32(&pool.clientCount, 5)
// Try to create with different config
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15 // Different config
transport2 := pool.GetOrCreateTransport(config2)
// Should return existing transport since at limit
assert.NotNil(t, transport2)
assert.Equal(t, transport1, transport2)
})
t.Run("updates last used time", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
firstTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
// Get again
transport2 := pool.GetOrCreateTransport(config)
require.NotNil(t, transport2)
pool.mu.RLock()
secondTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
})
}
// TestSharedTransportPoolReleaseTransport tests transport release
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
t.Run("decrement ref count", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Get again to increase ref count
pool.GetOrCreateTransport(config)
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 2, refCount)
// Release
pool.ReleaseTransport(transport)
pool.mu.RLock()
newRefCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, newRefCount)
})
t.Run("ref count reaches zero", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
pool.mu.RUnlock()
// Release to zero
pool.ReleaseTransport(transport)
pool.mu.RLock()
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 0, shared.refCount)
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
})
t.Run("release non-existent transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create a transport not in the pool
fakeTransport := &http.Transport{}
// Should not panic
assert.NotPanics(t, func() {
pool.ReleaseTransport(fakeTransport)
})
})
t.Run("release updates last used", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
time.Sleep(10 * time.Millisecond)
beforeRelease := time.Now()
pool.ReleaseTransport(transport)
pool.mu.RLock()
key := pool.configKey(config)
lastUsed := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
})
}
// TestSharedTransportPoolCleanup tests cleanup functionality
func TestSharedTransportPoolCleanup(t *testing.T) {
t.Run("cleanup all transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create multiple transports
config1 := DefaultHTTPClientConfig()
pool.GetOrCreateTransport(config1)
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15
pool.GetOrCreateTransport(config2)
assert.Greater(t, len(pool.transports), 0)
// Cleanup
pool.Cleanup()
assert.Len(t, pool.transports, 0, "all transports should be removed")
})
t.Run("cleanup cancels context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
pool.Cleanup()
select {
case <-pool.ctx.Done():
// Context was canceled
case <-time.After(100 * time.Millisecond):
t.Error("context should be canceled")
}
})
t.Run("cleanup with no transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
assert.NotPanics(t, func() {
pool.Cleanup()
})
})
t.Run("cleanup closes idle connections", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Cleanup should call CloseIdleConnections on each transport
pool.Cleanup()
// Verify transports map is cleared
assert.Empty(t, pool.transports)
})
}
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
if testing.Short() {
t.Skip("Skipping cleanup goroutine test in short mode")
}
t.Run("cleanup removes idle transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport and release it
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.ReleaseTransport(transport)
// Set lastUsed to old time
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Start cleanup in background (simulating what would happen)
// Note: We're testing the cleanup logic manually here
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
// Transport should be removed
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.False(t, exists, "old idle transport should be removed")
})
t.Run("cleanup preserves active transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport with refs
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Keep ref count > 0, but set old lastUsed
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup logic
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
}
}
pool.mu.Unlock()
// Transport should still exist (has ref count)
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.True(t, exists, "transport with references should be preserved")
})
t.Run("cleanup respects context cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Start cleanup goroutine
done := make(chan bool)
go func() {
pool.cleanupIdleTransports(ctx)
done <- true
}()
// Cancel context
cancel()
// Should exit quickly
select {
case <-done:
// Success
case <-time.After(2 * time.Second):
t.Error("cleanup goroutine should exit on context cancellation")
}
})
}
// TestCreatePooledHTTPClient tests pooled client creation
func TestCreatePooledHTTPClient(t *testing.T) {
t.Run("create client with default config", func(t *testing.T) {
config := DefaultHTTPClientConfig()
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.Transport)
assert.Equal(t, config.Timeout, client.Timeout)
})
t.Run("create multiple clients reuse transport", func(t *testing.T) {
// Reset global pool for clean test
globalTransportPoolOnce = sync.Once{}
globalTransportPool = nil
config := DefaultHTTPClientConfig()
client1 := CreatePooledHTTPClient(config)
client2 := CreatePooledHTTPClient(config)
require.NotNil(t, client1)
require.NotNil(t, client2)
// Should use same transport
assert.Equal(t, client1.Transport, client2.Transport)
})
t.Run("redirect policy is set", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 3
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.CheckRedirect)
// Test redirect limit
var redirects []*http.Request
for i := 0; i < 3; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after max redirects")
})
t.Run("default redirect limit", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 0 // Should default to 10
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
// Test default redirect limit (10)
var redirects []*http.Request
for i := 0; i < 10; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after 10 redirects")
})
}
// TestGetGlobalTransportPool tests singleton pattern
func TestGetGlobalTransportPool(t *testing.T) {
t.Run("returns same instance", func(t *testing.T) {
pool1 := GetGlobalTransportPool()
pool2 := GetGlobalTransportPool()
assert.Equal(t, pool1, pool2, "should return same singleton instance")
})
t.Run("pool is initialized", func(t *testing.T) {
pool := GetGlobalTransportPool()
require.NotNil(t, pool)
assert.NotNil(t, pool.transports)
assert.Equal(t, 20, pool.maxConns)
assert.Equal(t, int32(5), pool.maxClients)
assert.NotNil(t, pool.ctx)
assert.NotNil(t, pool.cancel)
})
}
// TestSharedTransportPoolConcurrency tests thread safety
func TestSharedTransportPoolConcurrency(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrency test in short mode")
}
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10, // Allow more for concurrency test
}
config := DefaultHTTPClientConfig()
const numGoroutines = 20
var wg sync.WaitGroup
transports := make([]*http.Transport, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
transports[idx] = pool.GetOrCreateTransport(config)
}(i)
}
wg.Wait()
// All should get same transport
firstTransport := transports[0]
for i := 1; i < numGoroutines; i++ {
if transports[i] != nil {
assert.Equal(t, firstTransport, transports[i])
}
}
})
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
// Increase ref count
for i := 0; i < 20; i++ {
pool.GetOrCreateTransport(config)
}
const numReleases = 20
var wg sync.WaitGroup
for i := 0; i < numReleases; i++ {
wg.Add(1)
go func() {
defer wg.Done()
pool.ReleaseTransport(transport)
}()
}
wg.Wait()
// Should not panic and ref count should be decremented
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
})
}
// TestSharedTransportPoolEdgeCases tests edge cases
func TestSharedTransportPoolEdgeCases(t *testing.T) {
t.Run("config key generation", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config1.MaxIdleConnsPerHost = 5
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 10
config2.MaxIdleConnsPerHost = 5
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.Equal(t, key1, key2, "same config should produce same key")
})
t.Run("different configs produce different keys", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 20
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
})
t.Run("client count decrements on cleanup", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
initialCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(1), initialCount)
// Release and mark as old
pool.ReleaseTransport(transport)
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
finalCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
})
}
+52 -47
View File
@@ -15,20 +15,21 @@ import (
// XSS, path traversal, and other injection attacks. It validates and sanitizes
// various input types used in OIDC authentication flows.
type InputValidator struct {
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
}
// ValidationResult encapsulates the outcome of input validation.
@@ -46,13 +47,14 @@ type ValidationResult struct {
// It specifies maximum lengths for various input types and controls whether
// strict validation mode is enabled.
type InputValidationConfig struct {
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
}
// DefaultInputValidationConfig returns a secure default configuration
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
}
}
// Check for private IP ranges (RFC 1918)
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
if !iv.allowPrivateIPAddresses {
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
+9 -9
View File
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
}
t.Run("Valid token validation", func(t *testing.T) {
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
result := validator.ValidateToken(validToken)
if !result.IsValid {
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
tests := []struct {
name string
token string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidJWTToken",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
expectValid: true,
description: "Valid JWT token should pass validation",
},
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
},
{
name: "MaliciousJWTWithExtraData",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
expectValid: false,
description: "JWT with extra malicious data should fail validation",
},
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
tests := []struct {
name string
email string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidEmail",
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
tests := []struct {
name string
url string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHTTPSURL",
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
name string
claimName string
claimValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidStringClaim",
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
name string
headerName string
headerValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHeader",
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
tests := []struct {
name string
username string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidUsername",
+4 -4
View File
@@ -726,20 +726,20 @@ type MockConfig struct {
}
type MockSession struct {
id string
userID string
created time.Time
lastUsed time.Time
data map[string]interface{}
id string
userID string
}
type TestResult struct {
UserID int
StartTime time.Time
EndTime time.Time
Error error
UserID int
Duration time.Duration
Success bool
Error error
}
// ============================================================================
+79
View File
@@ -0,0 +1,79 @@
package backends
import "time"
// BackendType represents the type of cache backend
type BackendType string
const (
BackendTypeMemory BackendType = "memory"
BackendTypeRedis BackendType = "redis"
BackendTypeHybrid BackendType = "hybrid"
// Aliases for backward compatibility
TypeMemory BackendType = "memory"
TypeRedis BackendType = "redis"
TypeHybrid BackendType = "hybrid"
)
// Config provides common configuration for cache backends
type Config struct {
L2Config *Config
L1Config *Config
RedisPrefix string
Type BackendType
RedisAddr string
RedisPassword string
PoolSize int
RedisDB int
CleanupInterval time.Duration
MaxMemoryBytes int64
MaxSize int
HealthCheckInterval time.Duration
AsyncWrites bool
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
}
// DefaultConfig returns a default configuration for in-memory caching
func DefaultConfig() *Config {
return &Config{
Type: BackendTypeMemory,
MaxSize: 1000,
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
CleanupInterval: 5 * time.Minute,
EnableMetrics: true,
}
}
// DefaultRedisConfig returns a default configuration for Redis caching
func DefaultRedisConfig(addr string) *Config {
return &Config{
Type: BackendTypeRedis,
RedisAddr: addr,
RedisDB: 0,
RedisPrefix: "traefikoidc:",
PoolSize: 10,
EnableCircuitBreaker: true,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
EnableMetrics: true,
}
}
// DefaultHybridConfig returns a default configuration for hybrid caching
func DefaultHybridConfig(redisAddr string) *Config {
return &Config{
Type: BackendTypeHybrid,
L1Config: &Config{
Type: BackendTypeMemory,
MaxSize: 500,
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
CleanupInterval: 1 * time.Minute,
},
L2Config: DefaultRedisConfig(redisAddr),
AsyncWrites: true,
EnableMetrics: true,
}
}
+59
View File
@@ -0,0 +1,59 @@
//go:build !yaegi
package backends
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDefaultHybridConfig verifies the default hybrid configuration
func TestDefaultHybridConfig(t *testing.T) {
redisAddr := "localhost:6379"
config := DefaultHybridConfig(redisAddr)
require.NotNil(t, config)
// Verify top-level config
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.True(t, config.AsyncWrites)
assert.True(t, config.EnableMetrics)
// Verify L1 (memory) config
require.NotNil(t, config.L1Config)
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
assert.Equal(t, 500, config.L1Config.MaxSize)
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
// Verify L2 (Redis) config exists
require.NotNil(t, config.L2Config)
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
}
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
tests := []struct {
name string
redisAddr string
}{
{"localhost", "localhost:6379"},
{"remote host", "redis.example.com:6379"},
{"IP address", "192.168.1.100:6379"},
{"custom port", "localhost:6380"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := DefaultHybridConfig(tt.redisAddr)
require.NotNil(t, config)
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.NotNil(t, config.L1Config)
assert.NotNil(t, config.L2Config)
})
}
}
+38
View File
@@ -0,0 +1,38 @@
package backends
import "errors"
var (
// ErrBackendClosed is returned when operating on a closed backend
ErrBackendClosed = errors.New("cache backend is closed")
// ErrKeyNotFound is returned when a key doesn't exist
ErrKeyNotFound = errors.New("key not found")
// ErrCacheMiss indicates the requested key was not found in the cache
ErrCacheMiss = errors.New("cache miss")
// ErrBackendUnavailable indicates the cache backend is not available
ErrBackendUnavailable = errors.New("cache backend unavailable")
// ErrInvalidValue indicates the cached value is invalid or corrupted
ErrInvalidValue = errors.New("invalid cached value")
// ErrInvalidTTL is returned when TTL is invalid
ErrInvalidTTL = errors.New("invalid TTL")
// ErrConnectionFailed is returned when connection fails
ErrConnectionFailed = errors.New("connection failed")
// ErrCircuitOpen is returned when circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTimeout is returned when operation times out
ErrTimeout = errors.New("operation timeout")
// ErrSerializationFailed is returned when serialization fails
ErrSerializationFailed = errors.New("serialization failed")
// ErrDeserializationFailed is returned when deserialization fails
ErrDeserializationFailed = errors.New("deserialization failed")
)
+685
View File
@@ -0,0 +1,685 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
// It provides automatic failover, async writes for non-critical data, and optimized read paths
type HybridBackend struct {
lastL2Error atomic.Value
secondary CacheBackend
primary CacheBackend
logger Logger
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
errors atomic.Int64
l2Writes atomic.Int64
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
fallbackMode atomic.Bool
}
// asyncWriteItem represents an async write operation
type asyncWriteItem struct {
ctx context.Context
key string
value []byte
ttl time.Duration
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// defaultLogger provides a basic logger implementation
type defaultLogger struct {
*log.Logger
}
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
l.Printf("[DEBUG] "+format, args...)
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
l.Printf("[INFO] "+format, args...)
}
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
l.Printf("[WARN] "+format, args...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
l.Printf("[ERROR] "+format, args...)
}
// HybridConfig provides configuration for the hybrid backend
type HybridConfig struct {
Primary CacheBackend
Secondary CacheBackend
Logger Logger
SyncWriteCacheTypes map[string]bool
AsyncBufferSize int
}
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.Primary == nil {
return nil, fmt.Errorf("primary (L1) backend is required")
}
if config.Secondary == nil {
return nil, fmt.Errorf("secondary (L2) backend is required")
}
if config.Logger == nil {
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
}
if config.AsyncBufferSize <= 0 {
config.AsyncBufferSize = 1000
}
// Default critical cache types that require synchronous writes
if config.SyncWriteCacheTypes == nil {
config.SyncWriteCacheTypes = map[string]bool{
"blacklist": true, // Token blacklist must be immediately consistent
"token": true, // Token validation is critical
}
}
ctx, cancel := context.WithCancel(context.Background())
h := &HybridBackend{
primary: config.Primary,
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
}
// Start async write worker
h.wg.Add(1)
go h.asyncWriteWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
return h, nil
}
// Set stores a value in both L1 and L2 caches
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Always write to L1 first (synchronous)
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L1 cache: %v", err)
// Continue to try L2 even if L1 fails
} else {
h.l1Writes.Add(1)
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
return nil // Don't fail the operation if L2 is down
}
// Determine if this should be a sync or async write based on cache type
cacheType := h.extractCacheType(key)
requiresSync := h.syncWriteCacheTypes[cacheType]
if requiresSync {
// Synchronous write for critical cache types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
h.recordL2Error()
// Don't fail the operation - L1 write succeeded
return nil
}
h.l2Writes.Add(1)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
} else {
// Asynchronous write for non-critical cache types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
h.logger.Debugf("Queued async write to L2 for key: %s", key)
default:
// Buffer is full, log and continue
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
h.errors.Add(1)
}
}
return nil
}
// Get retrieves a value from cache, checking L1 first, then L2
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
// Try L1 first
value, ttl, exists, err := h.primary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L1 get error for key %s: %v", key, err)
}
if exists {
h.l1Hits.Add(1)
return value, ttl, true, nil
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.misses.Add(1)
return nil, 0, false, nil
}
// Try L2
value, ttl, exists, err = h.secondary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L2 get error for key %s: %v", key, err)
h.recordL2Error()
h.misses.Add(1)
return nil, 0, false, nil // Don't propagate L2 errors
}
if !exists {
h.misses.Add(1)
return nil, 0, false, nil
}
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
return value, ttl, true, nil
}
// Delete removes a key from both L1 and L2 caches
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
var deleted bool
// Delete from L1
if d, err := h.primary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
} else if d {
deleted = true
}
// Delete from L2 if not in fallback mode
if !h.fallbackMode.Load() {
if d, err := h.secondary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
h.recordL2Error()
} else if d {
deleted = true
}
}
return deleted, nil
}
// Exists checks if a key exists in either cache
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
// Check L1 first
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
return true, nil
}
// Check L2 if not in fallback mode
if !h.fallbackMode.Load() {
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
return true, nil
}
}
return false, nil
}
// Clear removes all keys from both caches
func (h *HybridBackend) Clear(ctx context.Context) error {
var lastErr error
// Clear L1
if err := h.primary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L1 cache: %v", err)
lastErr = err
}
// Clear L2 if not in fallback mode
if !h.fallbackMode.Load() {
if err := h.secondary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L2 cache: %v", err)
h.recordL2Error()
lastErr = err
}
}
return lastErr
}
// GetStats returns statistics for the hybrid cache
func (h *HybridBackend) GetStats() map[string]interface{} {
l1Hits := h.l1Hits.Load()
l2Hits := h.l2Hits.Load()
misses := h.misses.Load()
total := l1Hits + l2Hits + misses
stats := map[string]interface{}{
"type": TypeHybrid,
"l1_hits": l1Hits,
"l2_hits": l2Hits,
"misses": misses,
"total": total,
"l1_writes": h.l1Writes.Load(),
"l2_writes": h.l2Writes.Load(),
"errors": h.errors.Load(),
"fallback_mode": h.fallbackMode.Load(),
}
if total > 0 {
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
}
// Add sub-backend stats
stats["l1_stats"] = h.primary.GetStats()
stats["l2_stats"] = h.secondary.GetStats()
// Add last L2 error time if available
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok {
stats["last_l2_error"] = t.Format(time.RFC3339)
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
}
}
return stats
}
// Ping checks if both backends are healthy
func (h *HybridBackend) Ping(ctx context.Context) error {
// Check L1
if err := h.primary.Ping(ctx); err != nil {
return fmt.Errorf("L1 ping failed: %w", err)
}
// Check L2 (but don't fail if it's down)
if err := h.secondary.Ping(ctx); err != nil {
h.logger.Warnf("L2 ping failed: %v", err)
h.recordL2Error()
// Don't return error - we can operate with L1 only
} else {
// L2 is healthy, clear fallback mode if it was set
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend recovered, exiting fallback mode")
}
}
return nil
}
// Close shuts down the hybrid backend
func (h *HybridBackend) Close() error {
// Cancel context to stop workers
h.cancel()
// Close async write channel
close(h.asyncWriteBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Workers finished
case <-time.After(5 * time.Second):
h.logger.Warnf("Timeout waiting for workers to finish")
}
var lastErr error
// Close backends
if err := h.primary.Close(); err != nil {
h.logger.Errorf("Failed to close L1 backend: %v", err)
lastErr = err
}
if err := h.secondary.Close(); err != nil {
h.logger.Errorf("Failed to close L2 backend: %v", err)
lastErr = err
}
h.logger.Infof("HybridBackend closed")
return lastErr
}
// GetMany retrieves multiple values efficiently
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
results := make(map[string][]byte, len(keys))
missingKeys := make([]string, 0)
// Try L1 first for all keys
for _, key := range keys {
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
results[key] = value
h.l1Hits.Add(1)
} else {
missingKeys = append(missingKeys, key)
}
}
// If all found in L1 or in fallback mode, return
if len(missingKeys) == 0 || h.fallbackMode.Load() {
return results, nil
}
// Try L2 for missing keys using batch operation if available
if batcher, ok := h.secondary.(interface {
GetMany(context.Context, []string) (map[string][]byte, error)
}); ok {
l2Results, err := batcher.GetMany(ctx, missingKeys)
if err != nil {
h.logger.Debugf("L2 batch get error: %v", err)
h.recordL2Error()
} else {
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
}
}
} else {
// Fallback to individual gets
for _, key := range missingKeys {
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
}
}
}
// Count misses for keys not found anywhere
for _, key := range keys {
if _, found := results[key]; !found {
h.misses.Add(1)
}
}
return results, nil
}
// SetMany stores multiple key-value pairs efficiently
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if len(items) == 0 {
return nil
}
// Write to L1 first
for key, value := range items {
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
} else {
h.l1Writes.Add(1)
}
}
// Skip L2 if in fallback mode
if h.fallbackMode.Load() {
return nil
}
// Check if L2 supports batch operations
if batcher, ok := h.secondary.(interface {
SetMany(context.Context, map[string][]byte, time.Duration) error
}); ok {
if err := batcher.SetMany(ctx, items, ttl); err != nil {
h.logger.Warnf("Failed to batch write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(int64(len(items)))
}
} else {
// Fallback to individual sets
for key, value := range items {
cacheType := h.extractCacheType(key)
if h.syncWriteCacheTypes[cacheType] {
// Sync write for critical types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
}
} else {
// Async write for non-critical types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
// Queued
default:
h.logger.Warnf("Async buffer full for batch write")
}
}
}
}
return nil
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items with best effort
for len(h.asyncWriteBuffer) > 0 {
select {
case item := <-h.asyncWriteBuffer:
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.asyncWriteBuffer:
if !ok {
return
}
// Skip if in fallback mode
if h.fallbackMode.Load() {
continue
}
// Perform the write with a timeout
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.errors.Add(1)
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
}
cancel()
}
}
}
// healthMonitor periodically checks L2 health and manages fallback mode
func (h *HybridBackend) healthMonitor() {
defer h.wg.Done()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := h.secondary.Ping(ctx); err != nil {
if !h.fallbackMode.Load() {
h.fallbackMode.Store(true)
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
}
} else {
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend healthy, exiting fallback mode")
}
}
cancel()
}
}
}
// recordL2Error records the timestamp of an L2 error
func (h *HybridBackend) recordL2Error() {
h.lastL2Error.Store(time.Now())
// Check if we should enter fallback mode based on recent errors
if !h.fallbackMode.Load() {
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
h.fallbackMode.Store(true)
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
}
}
}
}
// extractCacheType attempts to determine the cache type from the key
func (h *HybridBackend) extractCacheType(key string) string {
// Simple heuristic based on key prefixes
// This should match the actual cache type strategy in the main application
if len(key) > 10 {
prefix := key[:10]
switch {
case contains(prefix, "blacklist"):
return "blacklist"
case contains(prefix, "token"):
return "token"
case contains(prefix, "metadata"):
return "metadata"
case contains(prefix, "jwk"):
return "jwk"
case contains(prefix, "session"):
return "session"
case contains(prefix, "introspect"):
return "introspection"
}
}
return "general"
}
// contains checks if a string contains a substring (case-insensitive)
func contains(s, substr string) bool {
if len(substr) > len(s) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if toLower(s[i+j]) != toLower(substr[j]) {
match = false
break
}
}
if match {
return true
}
}
return false
}
// toLower converts a byte to lowercase
func toLower(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + 32
}
return b
}
File diff suppressed because it is too large Load Diff
+102
View File
@@ -0,0 +1,102 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"time"
)
// CacheBackend defines the interface for all cache backend implementations
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
type CacheBackend interface {
// Set stores a value in the cache with the specified TTL
// Returns an error if the operation fails
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Get retrieves a value from the cache
// Returns: value, remaining TTL, exists flag, and error
// If the key doesn't exist, exists will be false
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
// Delete removes a key from the cache
// Returns true if the key was deleted, false if it didn't exist
Delete(ctx context.Context, key string) (bool, error)
// Exists checks if a key exists in the cache
Exists(ctx context.Context, key string) (bool, error)
// Clear removes all keys from the cache
Clear(ctx context.Context) error
// GetStats returns cache statistics
// Stats include: hits, misses, size, memory usage, etc.
GetStats() map[string]interface{}
// Close shuts down the cache backend and releases resources
Close() error
// Ping checks if the backend is healthy and responsive
Ping(ctx context.Context) error
}
// BackendStats represents statistics for a cache backend
type BackendStats struct {
StartTime time.Time
LastErrorTime time.Time
Type BackendType
LastError string
Deletes int64
Errors int64
Evictions int64
CurrentSize int64
MaxSize int64
MemoryUsage int64
AverageGetLatency time.Duration
AverageSetLatency time.Duration
Sets int64
Misses int64
Uptime time.Duration
Hits int64
}
// BackendCapabilities describes the capabilities of a cache backend
type BackendCapabilities struct {
// Distributed indicates if the backend is distributed across multiple instances
Distributed bool
// Persistent indicates if the backend persists data across restarts
Persistent bool
// Eviction indicates if the backend supports automatic eviction
Eviction bool
// TTL indicates if the backend supports TTL (time-to-live)
TTL bool
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
MaxKeySize int64
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
MaxValueSize int64
// MaxKeys is the maximum number of keys (0 = unlimited)
MaxKeys int64
// SupportsExpire indicates if the backend supports expiration
SupportsExpire bool
// SupportsMultiGet indicates if the backend supports batch get operations
SupportsMultiGet bool
// SupportsTransaction indicates if the backend supports transactions
SupportsTransaction bool
// SupportsCompression indicates if the backend supports compression
SupportsCompression bool
// RequiresSerialize indicates if values must be serialized
RequiresSerialize bool
// AtomicOperations indicates if the backend supports atomic operations
AtomicOperations bool
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
func TestCacheBackendContract(t *testing.T) {
// Test suite will be run against each backend type
t.Run("MemoryBackend", func(t *testing.T) {
backend := setupMemoryBackend(t)
runContractTests(t, backend)
})
t.Run("RedisBackend", func(t *testing.T) {
backend := setupRedisBackend(t)
runContractTests(t, backend)
})
t.Run("HybridBackend", func(t *testing.T) {
backend := setupHybridBackend(t)
runContractTests(t, backend)
})
}
// runContractTests executes all contract tests against a backend
func runContractTests(t *testing.T, backend CacheBackend) {
t.Helper()
ctx := context.Background()
t.Run("BasicSetGet", func(t *testing.T) {
testBasicSetGet(t, ctx, backend)
})
t.Run("GetNonExistent", func(t *testing.T) {
testGetNonExistent(t, ctx, backend)
})
t.Run("UpdateExisting", func(t *testing.T) {
testUpdateExisting(t, ctx, backend)
})
t.Run("Delete", func(t *testing.T) {
testDelete(t, ctx, backend)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
testDeleteNonExistent(t, ctx, backend)
})
t.Run("Exists", func(t *testing.T) {
testExists(t, ctx, backend)
})
t.Run("TTLExpiration", func(t *testing.T) {
testTTLExpiration(t, ctx, backend)
})
t.Run("Clear", func(t *testing.T) {
testClear(t, ctx, backend)
})
t.Run("Ping", func(t *testing.T) {
testPing(t, ctx, backend)
})
t.Run("Stats", func(t *testing.T) {
testStats(t, ctx, backend)
})
t.Run("ConcurrentAccess", func(t *testing.T) {
testConcurrentAccess(t, ctx, backend)
})
t.Run("LargeValues", func(t *testing.T) {
testLargeValues(t, ctx, backend)
})
t.Run("EmptyValues", func(t *testing.T) {
testEmptyValues(t, ctx, backend)
})
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
testSpecialCharactersInKeys(t, ctx, backend)
})
}
// testBasicSetGet verifies basic set and get operations
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "test-key-1"
value := []byte("test-value-1")
ttl := 1 * time.Minute
// Set value
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err, "Set should not return error")
// Get value
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error")
assert.True(t, exists, "Key should exist")
assert.Equal(t, value, retrieved, "Retrieved value should match")
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
}
// testGetNonExistent verifies behavior when getting non-existent keys
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-key"
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error for non-existent key")
assert.False(t, exists, "Key should not exist")
assert.Nil(t, retrieved, "Value should be nil")
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
}
// testUpdateExisting verifies updating an existing key
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
ttl := 1 * time.Minute
// Set initial value
err := backend.Set(ctx, key, value1, ttl)
require.NoError(t, err)
// Update value
err = backend.Set(ctx, key, value2, ttl)
require.NoError(t, err)
// Verify updated value
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved, "Value should be updated")
}
// testDelete verifies delete operation
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "delete-key"
value := []byte("delete-value")
// Set value
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Verify exists
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Delete
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted, "Delete should return true for existing key")
// Verify deleted
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after delete")
}
// testDeleteNonExistent verifies deleting non-existent keys
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-delete-key"
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.False(t, deleted, "Delete should return false for non-existent key")
}
// testExists verifies the Exists operation
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "exists-key"
value := []byte("exists-value")
// Check non-existent key
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist initially")
// Set value
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check existing key
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist after Set")
}
// testTTLExpiration verifies TTL expiration behavior
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
// Set with short TTL
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist immediately after Set")
// Wait for expiration
time.Sleep(200 * time.Millisecond)
// Verify expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after TTL expiration")
}
// testClear verifies Clear operation
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
// Set multiple keys
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Give async writes time to complete before clearing
// This prevents race conditions with async write workers
time.Sleep(50 * time.Millisecond)
// Clear all
err := backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after Clear")
}
}
// testPing verifies Ping operation
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
err := backend.Ping(ctx)
assert.NoError(t, err, "Ping should succeed on healthy backend")
}
// testStats verifies GetStats operation
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
stats := backend.GetStats()
assert.NotNil(t, stats, "Stats should not be nil")
// Stats should contain basic metrics
_, hasHits := stats["hits"]
_, hasMisses := stats["misses"]
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
}
// testConcurrentAccess verifies thread safety
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
var wg sync.WaitGroup
goroutines := 10
iterations := 20
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
}
}(i)
}
wg.Wait()
}
// testLargeValues verifies handling of large values
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "large-value-key"
value := GenerateLargeValue(1024 * 1024) // 1MB
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle large values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
}
// testEmptyValues verifies handling of empty values
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "empty-value-key"
value := []byte{}
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle empty values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Empty value should exist")
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
}
// testSpecialCharactersInKeys verifies handling of special characters in keys
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
specialKeys := []string{
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
"key|with|pipes",
}
for _, key := range specialKeys {
value := []byte(fmt.Sprintf("value-for-%s", key))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle special character in key: %s", key)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key with special characters should exist: %s", key)
assert.Equal(t, value, retrieved)
}
}
// Helper functions to setup different backend types
// These will be implemented in respective test files
func setupMemoryBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in memory_test.go
// For now, return nil to allow compilation
t.Skip("MemoryBackend implementation pending")
return nil
}
func setupRedisBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in redis_test.go
// For now, return nil to allow compilation
t.Skip("RedisBackend implementation pending")
return nil
}
func setupHybridBackend(t *testing.T) CacheBackend {
t.Helper()
primary := newMockBackend()
secondary := newMockBackend()
config := &HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 100,
Logger: NewTestLogger(t),
}
hybrid, err := NewHybridBackend(config)
require.NoError(t, err)
t.Cleanup(func() {
hybrid.Close()
})
return hybrid
}
+508
View File
@@ -0,0 +1,508 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
value interface{}
element *list.Element
key string
accessCount int64
size int64
}
// isExpired checks if the item is expired
func (item *memoryCacheItem) isExpired() bool {
if item.expiresAt.IsZero() {
return false
}
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
type MemoryCacheBackend struct {
startTime time.Time
lastErrorTime time.Time
items map[string]*memoryCacheItem
lruList *list.List
cleanupDone chan bool
cleanupTicker *time.Ticker
evictionPolicy string
lastError string
currentMemory int64
misses atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
sets atomic.Int64
hits atomic.Int64
maxSize int64
currentSize int64
maxMemory int64
cleanupInterval time.Duration
mu sync.RWMutex
closed atomic.Bool
}
// NewMemoryCacheBackend creates a new memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = 10000 // Default to 10k items
}
if maxMemory <= 0 {
maxMemory = 100 * 1024 * 1024 // Default to 100MB
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
}
m := &MemoryCacheBackend{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
evictionPolicy: "lru",
cleanupDone: make(chan bool),
}
// Start cleanup goroutine
m.cleanupTicker = time.NewTicker(cleanupInterval)
go m.cleanupLoop()
return m
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
select {
case <-m.cleanupTicker.C:
m.cleanupExpired()
case <-m.cleanupDone:
return
}
}
}
// cleanupExpired removes all expired items from the cache
func (m *MemoryCacheBackend) cleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
var keysToDelete []string
for key, item := range m.items {
if item.isExpired() {
keysToDelete = append(keysToDelete, key)
}
}
for _, key := range keysToDelete {
m.deleteItemLocked(key)
}
}
// Get retrieves a value from the cache
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalGetTime.Add(duration)
m.getCount.Add(1)
}()
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
if item.isExpired() {
m.mu.Lock()
m.deleteItemLocked(key)
m.mu.Unlock()
m.misses.Add(1)
return nil, ErrCacheMiss
}
// Update access time and count
m.mu.Lock()
item.accessedAt = time.Now()
item.accessCount++
// Move to front of LRU list
if m.evictionPolicy == "lru" && item.element != nil {
m.lruList.MoveToFront(item.element)
}
m.mu.Unlock()
m.hits.Add(1)
return item.value, nil
}
// Set stores a value in the cache with optional TTL
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalSetTime.Add(duration)
m.setCount.Add(1)
}()
// Calculate item size (simplified estimation)
itemSize := int64(len(key)) + estimateValueSize(value)
m.mu.Lock()
defer m.mu.Unlock()
// Check if we need to evict items
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
m.evictLocked()
}
// Check if key exists
if oldItem, exists := m.items[key]; exists {
m.currentMemory -= oldItem.size
if oldItem.element != nil {
m.lruList.Remove(oldItem.element)
}
} else {
m.currentSize++
}
now := time.Now()
var expiresAt time.Time
if ttl > 0 {
expiresAt = now.Add(ttl)
}
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: itemSize,
}
// Add to LRU list
if m.evictionPolicy == "lru" {
item.element = m.lruList.PushFront(item)
}
m.items[key] = item
m.currentMemory += itemSize
m.sets.Add(1)
return nil
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.items[key]; !exists {
return nil
}
m.deleteItemLocked(key)
m.deletes.Add(1)
return nil
}
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
if item, exists := m.items[key]; exists {
m.currentMemory -= item.size
m.currentSize--
if item.element != nil {
m.lruList.Remove(item.element)
}
delete(m.items, key)
}
}
// evictLocked evicts items based on the eviction policy (must be called with lock held)
func (m *MemoryCacheBackend) evictLocked() {
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
// Evict least recently used item
element := m.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
m.deleteItemLocked(item.key)
m.evictions.Add(1)
}
}
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
return false, nil
}
return !item.isExpired(), nil
}
// Clear removes all items from the cache
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryCacheItem)
m.lruList = list.New()
m.currentSize = 0
m.currentMemory = 0
return nil
}
// Keys returns all keys matching the pattern (use "*" for all keys)
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var keys []string
for key, item := range m.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys, nil
}
// Size returns the number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
return m.currentSize, nil
}
// TTL returns the remaining time-to-live for a key
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists || item.isExpired() {
return 0, ErrCacheMiss
}
if item.expiresAt.IsZero() {
return 0, nil // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, nil
}
return remaining, nil
}
// Expire updates the TTL for an existing key
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists || item.isExpired() {
return ErrCacheMiss
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return nil
}
// GetStats returns statistics about the cache backend
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
m.mu.RUnlock()
avgGetLatency := time.Duration(0)
if getCount := m.getCount.Load(); getCount > 0 {
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
}
avgSetLatency := time.Duration(0)
if setCount := m.setCount.Load(); setCount > 0 {
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
}
return &BackendStats{
Type: TypeMemory,
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Sets: m.sets.Load(),
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: m.currentSize,
MaxSize: m.maxSize,
MemoryUsage: m.currentMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
LastErrorTime: lastErrorTime,
Uptime: time.Since(m.startTime),
StartTime: m.startTime,
}, nil
}
// Ping checks if the backend is healthy
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
return nil
}
// Close closes the backend and releases resources
func (m *MemoryCacheBackend) Close() error {
if m.closed.Swap(true) {
return nil // Already closed
}
m.cleanupTicker.Stop()
close(m.cleanupDone)
m.mu.Lock()
m.items = nil
m.lruList = nil
m.mu.Unlock()
return nil
}
// IsHealthy returns true if the backend is healthy
func (m *MemoryCacheBackend) IsHealthy() bool {
return !m.closed.Load()
}
// Type returns the backend type
func (m *MemoryCacheBackend) Type() BackendType {
return TypeMemory
}
// Capabilities returns the backend capabilities
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
return &BackendCapabilities{
Distributed: false,
Persistent: false,
Eviction: true,
TTL: true,
MaxKeySize: 1024, // 1KB
MaxValueSize: 10485760, // 10MB
MaxKeys: m.maxSize,
SupportsExpire: true,
SupportsMultiGet: true,
SupportsTransaction: false,
SupportsCompression: false,
RequiresSerialize: false,
}
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
// This is a simplified estimation
// In production, you might want to use a more accurate method
switch v := value.(type) {
case string:
return int64(len(v))
case []byte:
return int64(len(v))
case int, int32, int64, uint, uint32, uint64:
return 8
case float32, float64:
return 8
case bool:
return 1
default:
// For complex types, use a default estimate
return 256
}
}
// matchPattern checks if a key matches a pattern (simplified glob matching)
func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching - in production, use a proper glob library
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
}
+182
View File
@@ -0,0 +1,182 @@
package backends
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
)
// setupBenchmarkRedis creates a miniredis instance for benchmarking
func setupBenchmarkRedis(b *testing.B) string {
b.Helper()
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() {
mr.Close()
})
return mr.Addr()
}
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
// Perform various operations
_, _ = conn.Do("SET", "bench-key", "bench-value")
_, _ = conn.Do("GET", "bench-key")
_, _ = conn.Do("EXISTS", "bench-key")
_, _ = conn.Do("DEL", "bench-key")
pool.Put(conn)
}
}
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
func BenchmarkRedisBackend_SetGet(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("benchmark test data with some content")
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Set operation
err := backend.Set(ctx, "bench-key", testData, 0)
if err != nil {
b.Fatal(err)
}
// Get operation
_, _, _, err = backend.Get(ctx, "bench-key")
if err != nil {
b.Fatal(err)
}
}
}
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("concurrent benchmark data")
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = backend.Set(ctx, "concurrent-key", testData, 0)
_, _, _, _ = backend.Get(ctx, "concurrent-key")
}
})
}
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
defer pool.Put(conn)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// This tests the pooling of RESPReader/RESPWriter
_, _ = conn.Do("PING")
}
}
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
func BenchmarkConnectionPool_GetPut(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
pool.Put(conn)
}
}
+783
View File
@@ -0,0 +1,783 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMemoryBackend_BasicOperations tests basic CRUD operations
func TestMemoryBackend_BasicOperations(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "test-key"
value := []byte("test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
assert.LessOrEqual(t, remainingTTL, ttl)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "delete-key"
value := []byte("delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
deleted, err := backend.Delete(ctx, "non-existent-delete")
require.NoError(t, err)
assert.False(t, deleted)
})
t.Run("Exists", func(t *testing.T) {
key := "exists-key"
value := []byte("exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
// Add multiple items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
err := backend.Clear(ctx)
require.NoError(t, err)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(0), size)
})
}
// TestMemoryBackend_TTLExpiration tests TTL and expiration
func TestMemoryBackend_TTLExpiration(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.CleanupInterval = 50 * time.Millisecond
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "short-ttl-key"
value := []byte("short-ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait for expiration
time.Sleep(150 * time.Millisecond)
// Should be expired
_, _, exists, err = backend.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLDecrement", func(t *testing.T) {
key := "ttl-decrement-key"
value := []byte("ttl-decrement-value")
ttl := 2 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Check TTL immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait a bit
time.Sleep(500 * time.Millisecond)
// Check TTL again - should be less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
})
t.Run("CleanupExpiredItems", func(t *testing.T) {
// Set multiple items with short TTL
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
err := backend.Set(ctx, key, value, 50*time.Millisecond)
require.NoError(t, err)
}
// Wait for cleanup to run
time.Sleep(200 * time.Millisecond)
// All items should be cleaned up
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Expired items should be cleaned up")
}
})
}
// TestMemoryBackend_LRUEviction tests LRU eviction
func TestMemoryBackend_LRUEviction(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 5
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Fill cache to max size
for i := 0; i < 5; i++ {
key := fmt.Sprintf("lru-key-%d", i)
value := []byte(fmt.Sprintf("lru-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Access first item to make it most recently used
_, _, exists, err := backend.Get(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists)
// Add a new item - should evict lru-key-1 (least recently used)
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
require.NoError(t, err)
// lru-key-0 should still exist (was accessed recently)
exists, err = backend.Exists(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item should not be evicted")
// lru-key-1 should be evicted
exists, err = backend.Exists(ctx, "lru-key-1")
require.NoError(t, err)
assert.False(t, exists, "Least recently used item should be evicted")
// Check eviction count
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestMemoryBackend_MemoryLimit tests memory-based eviction
func TestMemoryBackend_MemoryLimit(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 100
config.MaxMemoryBytes = 1024 // 1KB limit
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items until memory limit is reached
largeValue := make([]byte, 512) // 512 bytes each
for i := 0; i < 5; i++ {
key := fmt.Sprintf("mem-key-%d", i)
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
}
stats := backend.GetStats()
memory := stats["memory"].(int64)
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
}
// TestMemoryBackend_ConcurrentAccess tests thread safety
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
// Random deletes
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
// Verify stats are consistent
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
}
// TestMemoryBackend_UpdateExisting tests updating existing keys
func TestMemoryBackend_UpdateExisting(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
// Size should not increase (same key)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
}
// TestMemoryBackend_Stats tests statistics tracking
func TestMemoryBackend_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add items and track hits/misses
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
// Hit
backend.Get(ctx, "key1")
// Miss
backend.Get(ctx, "non-existent")
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestMemoryBackend_EmptyValues tests handling of empty values
func TestMemoryBackend_EmptyValues(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestMemoryBackend_LargeValues tests handling of large values
func TestMemoryBackend_LargeValues(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestMemoryBackend_Close tests proper cleanup on close
func TestMemoryBackend_Close(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
ctx := context.Background()
// Add some items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations after close should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
_, _, _, err = backend.Get(ctx, "close-key-0")
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Closing again should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestMemoryBackend_Ping tests ping operation
func TestMemoryBackend_Ping(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
func TestMemoryBackend_ValueIsolation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "isolation-key"
originalValue := []byte("original-value")
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
require.NoError(t, err)
// Get value and modify it
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Modify retrieved value
if len(retrieved) > 0 {
retrieved[0] = 'X'
}
// Get again - should be unchanged
retrieved2, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
}
// TestMemoryBackend_Keys tests the Keys method with pattern matching
func TestMemoryBackend_Keys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add test data
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
for _, key := range testKeys {
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
t.Run("AllKeys", func(t *testing.T) {
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.Len(t, keys, 5)
})
t.Run("SpecificPattern", func(t *testing.T) {
// Simple exact match
keys, err := backend.Keys(ctx, "user:1")
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Contains(t, keys, "user:1")
})
t.Run("ExcludesExpired", func(t *testing.T) {
// Add an expired key
expiredKey := "expired:key"
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.Keys(ctx, "*")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Size tests the Size method
func TestMemoryBackend_Size(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initially empty
size, err := backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), size)
// Add items
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key-%d", i)
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(5), size)
// Delete one
backend.Delete(ctx, "key-0")
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(4), size)
// After close
backend.Close()
_, err = backend.Size(ctx)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
}
// TestMemoryBackend_TTL tests the TTL method
func TestMemoryBackend_TTL(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ExistingKey", func(t *testing.T) {
key := "ttl-key"
ttl := 1 * time.Minute
err := backend.Set(ctx, key, []byte("value"), ttl)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Greater(t, remaining, 50*time.Second)
assert.LessOrEqual(t, remaining, ttl)
})
t.Run("NonExistentKey", func(t *testing.T) {
_, err := backend.TTL(ctx, "non-existent")
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("NoExpiration", func(t *testing.T) {
key := "no-expiry"
// TTL of 0 typically means no expiration
err := backend.Set(ctx, key, []byte("value"), 0)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
// No expiration returns 0
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.TTL(ctx, "key")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Expire tests the Expire method
func TestMemoryBackend_Expire(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("UpdateTTL", func(t *testing.T) {
key := "expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Update to shorter TTL
err = backend.Expire(ctx, key, 5*time.Second)
require.NoError(t, err)
// Check new TTL
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.LessOrEqual(t, remaining, 5*time.Second)
})
t.Run("NonExistentKey", func(t *testing.T) {
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("RemoveExpiration", func(t *testing.T) {
key := "no-expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Set TTL to 0 to remove expiration
err = backend.Expire(ctx, key, 0)
require.NoError(t, err)
// TTL should now be 0
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_IsHealthy tests the IsHealthy method
func TestMemoryBackend_IsHealthy(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
// Should be healthy when open
assert.True(t, backend.IsHealthy())
// Should be unhealthy after close
backend.Close()
assert.False(t, backend.IsHealthy())
}
// TestMemoryBackend_Type tests the Type method
func TestMemoryBackend_Type(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
backendType := backend.Type()
assert.Equal(t, TypeMemory, backendType)
}
// TestMemoryBackend_Capabilities tests the Capabilities method
func TestMemoryBackend_Capabilities(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
caps := backend.Capabilities()
require.NotNil(t, caps)
// Memory backend should not be distributed or persistent
assert.False(t, caps.Distributed)
assert.False(t, caps.Persistent)
// Should support eviction and TTL
assert.True(t, caps.Eviction)
assert.True(t, caps.TTL)
assert.True(t, caps.SupportsExpire)
assert.True(t, caps.SupportsMultiGet)
// Check limits
assert.Greater(t, caps.MaxKeySize, int64(0))
assert.Greater(t, caps.MaxValueSize, int64(0))
}
// TestMatchPattern tests the matchPattern helper function
func TestMatchPattern(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
key string
matches bool
}{
{"*", "any-key", true},
{"*", "another", true},
{"user:1", "user:1", true},
{"user:1", "user:2", false},
{"*:suffix", "prefix:suffix", true},
{"*suffix", "prefix-suffix", true},
{"*abc", "xyzabc", true},
{"*abc", "xyz", false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
result := matchPattern(tt.pattern, tt.key)
assert.Equal(t, tt.matches, result)
})
}
}
+153
View File
@@ -0,0 +1,153 @@
package backends
import (
"context"
"time"
)
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
type MemoryBackend struct {
*MemoryCacheBackend
}
// NewMemoryBackend creates a new memory backend from a config
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
maxSize := int64(config.MaxSize)
if maxSize <= 0 {
maxSize = 1000
}
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
return &MemoryBackend{
MemoryCacheBackend: cacheBackend,
}, nil
}
// Set stores a value in the cache with the specified TTL
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
if err == ErrBackendUnavailable {
return ErrBackendClosed
}
return err
}
// Get retrieves a value from the cache
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
val, err := m.MemoryCacheBackend.Get(ctx, key)
if err != nil {
if err == ErrCacheMiss {
return nil, 0, false, nil
}
if err == ErrBackendUnavailable {
return nil, 0, false, ErrBackendClosed
}
return nil, 0, false, err
}
// Get the item directly to check TTL
m.MemoryCacheBackend.mu.RLock()
item, exists := m.MemoryCacheBackend.items[key]
m.MemoryCacheBackend.mu.RUnlock()
if !exists {
return nil, 0, false, nil
}
var ttl time.Duration
if !item.expiresAt.IsZero() {
ttl = time.Until(item.expiresAt)
if ttl < 0 {
ttl = 0
}
}
// Convert interface{} to []byte
var valueBytes []byte
if val != nil {
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, we might need to handle other types
// For now, we'll just return an error
return nil, 0, false, ErrInvalidValue
}
}
return valueBytes, ttl, true, nil
}
// Delete removes a key from the cache
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
// Check if key exists first
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
if err != nil {
return false, err
}
if !exists {
return false, nil
}
err = m.MemoryCacheBackend.Delete(ctx, key)
if err != nil {
return false, err
}
return true, nil
}
// Exists checks if a key exists in the cache
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
return m.MemoryCacheBackend.Exists(ctx, key)
}
// Clear removes all keys from the cache
func (m *MemoryBackend) Clear(ctx context.Context) error {
return m.MemoryCacheBackend.Clear(ctx)
}
// GetStats returns cache statistics
func (m *MemoryBackend) GetStats() map[string]interface{} {
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
// Convert BackendStats to map
hitRate := float64(0)
total := stats.Hits + stats.Misses
if total > 0 {
hitRate = float64(stats.Hits) / float64(total)
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
}
}
// Close shuts down the cache backend and releases resources
func (m *MemoryBackend) Close() error {
return m.MemoryCacheBackend.Close()
}
// Ping checks if the backend is healthy and responsive
func (m *MemoryBackend) Ping(ctx context.Context) error {
return m.MemoryCacheBackend.Ping(ctx)
}
// Ensure MemoryBackend implements CacheBackend
var _ CacheBackend = (*MemoryBackend)(nil)
+470
View File
@@ -0,0 +1,470 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
)
// Pure-Go Redis client implementation
// Compatible with Yaegi interpreter (no unsafe package)
// Implements RESP protocol for basic Redis operations
var (
ErrPoolExhausted = errors.New("connection pool exhausted")
)
// RedisBackend implements a Redis-based cache backend using pure Go
type RedisBackend struct {
config *Config
pool *ConnectionPool
healthMonitor *HealthMonitor
// Metrics
hits atomic.Int64
misses atomic.Int64
// Lifecycle
closed atomic.Bool
mu sync.Mutex
}
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
func NewRedisBackend(config *Config) (*RedisBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.RedisAddr == "" {
return nil, fmt.Errorf("redis address is required")
}
// Create connection pool with health checks enabled
// Timeouts are kept short to prevent request pileup when Redis is slow/stalled.
// The UniversalCache uses 200ms context timeout, so socket timeouts should be
// shorter to allow proper context cancellation handling.
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 500 * time.Millisecond,
WriteTimeout: 500 * time.Millisecond,
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(poolConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Create health monitor
healthConfig := DefaultHealthMonitorConfig()
healthMonitor := NewHealthMonitor(pool, healthConfig)
backend := &RedisBackend{
config: config,
pool: pool,
healthMonitor: healthMonitor,
}
// Test connectivity
if err := backend.Ping(context.Background()); err != nil {
_ = pool.Close()
return nil, fmt.Errorf("failed to ping Redis: %w", err)
}
// Start health monitoring
healthMonitor.Start()
return backend, nil
}
// Set stores a value in Redis with TTL
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
// Execute with retry logic
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
var err error
// Use PSETEX for millisecond precision, SETEX for second precision
if ttl > 0 {
ttlMillis := ttl.Milliseconds()
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs (millisecond precision)
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs (second precision)
ttlSeconds := int(ttl.Seconds())
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
_, err = conn.Do("SET", prefixedKey, string(value))
}
return err
})
}
// Get retrieves a value from Redis
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if r.closed.Load() {
return nil, 0, false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
var resultValue []byte
var resultTTL time.Duration
var resultExists bool
// Execute with retry logic
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
// Get value
resp, err := conn.Do("GET", prefixedKey)
if err != nil {
if errors.Is(err, ErrNilResponse) {
r.misses.Add(1)
resultExists = false
return nil // Not an error, key just doesn't exist
}
return err
}
value, err := RESPString(resp)
if err != nil {
return err
}
// Get TTL
ttlResp, err := conn.Do("TTL", prefixedKey)
if err != nil {
// If TTL fails, still return the value
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = 0
resultExists = true
return nil
}
ttlSeconds, _ := RESPInt(ttlResp)
var ttl time.Duration
if ttlSeconds > 0 {
ttl = time.Duration(ttlSeconds) * time.Second
}
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = ttl
resultExists = true
return nil
})
return resultValue, resultTTL, resultExists, err
}
// Delete removes a key from Redis
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("DEL", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Exists checks if a key exists in Redis
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("EXISTS", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Clear removes all keys with the configured prefix
func (r *RedisBackend) Clear(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
// Use FLUSHDB if no prefix (clear entire DB)
if r.config.RedisPrefix == "" {
_, err := conn.Do("FLUSHDB")
return err
}
// With prefix, we need to scan and delete keys
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
pattern := r.config.RedisPrefix + "*"
resp, err := conn.Do("KEYS", pattern)
if err != nil {
return err
}
// Extract keys from array response
keys, ok := resp.([]interface{})
if !ok || len(keys) == 0 {
return nil
}
// Delete each key
for _, keyInterface := range keys {
key, err := RESPString(keyInterface)
if err != nil {
continue
}
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
}
return nil
}
// GetStats returns backend statistics
func (r *RedisBackend) GetStats() map[string]interface{} {
hits := r.hits.Load()
misses := r.misses.Load()
total := hits + misses
hitRate := float64(0)
if total > 0 {
hitRate = float64(hits) / float64(total)
}
stats := map[string]interface{}{
"backend": "redis-pure-go",
"address": r.config.RedisAddr,
"hits": hits,
"misses": misses,
"hit_rate": hitRate,
"pool": r.pool.Stats(),
}
// Add health monitor stats if available
if r.healthMonitor != nil {
stats["health"] = r.healthMonitor.GetStats()
}
return stats
}
// Ping checks Redis connectivity
func (r *RedisBackend) Ping(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
_, err = conn.Do("PING")
return err
}
// Close closes the Redis backend and all connections
func (r *RedisBackend) Close() error {
if r.closed.Swap(true) {
return nil // Already closed
}
r.mu.Lock()
defer r.mu.Unlock()
// Stop health monitor
if r.healthMonitor != nil {
r.healthMonitor.Stop()
}
// Close connection pool
if r.pool != nil {
return r.pool.Close()
}
return nil
}
// prefixKey adds the configured prefix to a key
func (r *RedisBackend) prefixKey(key string) string {
if r.config.RedisPrefix == "" {
return key
}
return r.config.RedisPrefix + key
}
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
for attempt := 0; attempt < maxRetries; attempt++ {
// Check context before each attempt to fail fast
if ctx.Err() != nil {
return ctx.Err()
}
conn, err := r.pool.Get(ctx)
if err != nil {
// If we can't get a connection and this is the last attempt, fail
if attempt == maxRetries-1 {
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
// Execute the operation
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
// If successful, return
if err == nil {
return nil
}
// If error is not retryable or last attempt, fail
if attempt == maxRetries-1 || !isRetryableError(err) {
return err
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
return fmt.Errorf("operation failed after %d attempts", maxRetries)
}
// isRetryableError determines if an error is worth retrying
func isRetryableError(err error) bool {
if err == nil {
return false
}
// Retry on connection errors, timeouts, etc.
// Don't retry on application-level errors like wrong type
errMsg := err.Error()
retryablePatterns := []string{
"connection",
"timeout",
"EOF",
"broken pipe",
"reset by peer",
}
for _, pattern := range retryablePatterns {
if contains(errMsg, pattern) {
return true
}
}
return false
}
// SetMany stores multiple values in Redis (batch operation)
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
// For simplicity, execute sequentially (can be optimized with pipelining later)
for key, value := range items {
if err := r.Set(ctx, key, value, ttl); err != nil {
return err
}
}
return nil
}
// GetMany retrieves multiple values from Redis
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
result := make(map[string][]byte)
// For simplicity, execute sequentially
for _, key := range keys {
value, _, exists, err := r.Get(ctx, key)
if err != nil {
return nil, err
}
if exists {
result[key] = value
}
}
return result, nil
}
+170
View File
@@ -0,0 +1,170 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
type HealthMonitor struct {
pool *ConnectionPool
config *HealthMonitorConfig
stopChan chan struct{}
wg sync.WaitGroup
lastCheckTime atomic.Int64
consecutiveFailures atomic.Int64
totalChecks atomic.Int64
totalFailures atomic.Int64
healthy atomic.Bool
running atomic.Bool
}
// HealthMonitorConfig configures the health monitor
type HealthMonitorConfig struct {
OnHealthChange func(healthy bool)
CheckInterval time.Duration
Timeout time.Duration
UnhealthyThreshold int
}
// DefaultHealthMonitorConfig returns default health monitor configuration
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
return &HealthMonitorConfig{
CheckInterval: 5 * time.Second,
Timeout: 3 * time.Second,
UnhealthyThreshold: 3,
}
}
// NewHealthMonitor creates a new health monitor
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
if config == nil {
config = DefaultHealthMonitorConfig()
}
hm := &HealthMonitor{
pool: pool,
config: config,
stopChan: make(chan struct{}),
}
hm.healthy.Store(true) // Assume healthy initially
return hm
}
// Start begins health monitoring
func (hm *HealthMonitor) Start() {
if hm.running.Swap(true) {
return // Already running
}
hm.wg.Add(1)
go hm.monitorLoop()
}
// Stop stops health monitoring
func (hm *HealthMonitor) Stop() {
if !hm.running.Swap(false) {
return // Not running
}
close(hm.stopChan)
hm.wg.Wait()
}
// IsHealthy returns the current health status
func (hm *HealthMonitor) IsHealthy() bool {
return hm.healthy.Load()
}
// GetStats returns health monitor statistics
func (hm *HealthMonitor) GetStats() map[string]interface{} {
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
return map[string]interface{}{
"healthy": hm.healthy.Load(),
"consecutive_failures": hm.consecutiveFailures.Load(),
"total_checks": hm.totalChecks.Load(),
"total_failures": hm.totalFailures.Load(),
"last_check": lastCheck,
}
}
// monitorLoop runs the health check loop
func (hm *HealthMonitor) monitorLoop() {
defer hm.wg.Done()
ticker := time.NewTicker(hm.config.CheckInterval)
defer ticker.Stop()
// Perform initial check immediately
hm.performHealthCheck()
for {
select {
case <-hm.stopChan:
return
case <-ticker.C:
hm.performHealthCheck()
}
}
}
// performHealthCheck executes a health check
func (hm *HealthMonitor) performHealthCheck() {
hm.totalChecks.Add(1)
hm.lastCheckTime.Store(time.Now().Unix())
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
defer cancel()
// Try to get a connection and ping Redis
conn, err := hm.pool.Get(ctx)
if err != nil {
hm.recordFailure()
return
}
defer hm.pool.Put(conn)
// Ping Redis
_, err = conn.Do("PING")
if err != nil {
hm.recordFailure()
return
}
// Success!
hm.recordSuccess()
}
// recordSuccess records a successful health check
func (hm *HealthMonitor) recordSuccess() {
wasHealthy := hm.healthy.Load()
hm.consecutiveFailures.Store(0)
hm.healthy.Store(true)
// Trigger callback if health changed
if !wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(true)
}
}
// recordFailure records a failed health check
func (hm *HealthMonitor) recordFailure() {
hm.totalFailures.Add(1)
failures := hm.consecutiveFailures.Add(1)
wasHealthy := hm.healthy.Load()
// Mark unhealthy if threshold exceeded
if failures >= int64(hm.config.UnhealthyThreshold) {
hm.healthy.Store(false)
// Trigger callback if health changed
if wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(false)
}
}
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHealthMonitor_BasicOperation tests basic health monitoring
func TestHealthMonitor_BasicOperation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create health monitor with fast check interval for testing
hmConfig := &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
}
hm := NewHealthMonitor(pool, hmConfig)
require.NotNil(t, hm)
// Initially should be healthy
assert.True(t, hm.IsHealthy())
// Start monitoring
hm.Start()
defer hm.Stop()
// Wait for a few checks
time.Sleep(500 * time.Millisecond)
// Should still be healthy
assert.True(t, hm.IsHealthy())
// Check stats
stats := hm.GetStats()
require.NotNil(t, stats)
assert.True(t, stats["healthy"].(bool))
assert.Greater(t, stats["total_checks"].(int64), int64(0))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var healthChangedCalled atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if !healthy {
healthChangedCalled.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
// Check stats
stats := hm.GetStats()
assert.False(t, stats["healthy"].(bool))
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
assert.Greater(t, stats["total_failures"].(int64), int64(0))
}
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var recoveryDetected atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if healthy {
recoveryDetected.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Should detect server failure")
// Clear error to simulate recovery
mr.ClearError()
// Wait for recovery
time.Sleep(350 * time.Millisecond)
// Should be healthy again
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
// Consecutive failures should be reset
stats := hm.GetStats()
assert.True(t, stats["healthy"].(bool))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_StartStop tests start/stop behavior
func TestHealthMonitor_StartStop(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
// Start monitoring
hm.Start()
assert.True(t, hm.running.Load())
// Starting again should be no-op
hm.Start()
assert.True(t, hm.running.Load())
// Stop monitoring
hm.Stop()
assert.False(t, hm.running.Load())
// Stopping again should be no-op
hm.Stop()
assert.False(t, hm.running.Load())
}
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create multiple monitors
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 150 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
})
// Start both
hm1.Start()
hm2.Start()
// Both should be healthy
time.Sleep(200 * time.Millisecond)
assert.True(t, hm1.IsHealthy())
assert.True(t, hm2.IsHealthy())
// Stop both
hm1.Stop()
hm2.Stop()
// Verify they stopped
assert.False(t, hm1.running.Load())
assert.False(t, hm2.running.Load())
}
// TestHealthMonitor_StatsAccuracy tests stats tracking
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Wait for some checks
time.Sleep(550 * time.Millisecond)
stats := hm.GetStats()
// Should have performed multiple checks
totalChecks := stats["total_checks"].(int64)
assert.GreaterOrEqual(t, totalChecks, int64(4))
// All checks should succeed
assert.Equal(t, int64(0), stats["total_failures"].(int64))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
// Last check time should be recent (within check interval + buffer)
// Use 2s tolerance to account for CI runner load and timing variance
lastCheck := stats["last_check"].(time.Time)
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
}
// TestHealthMonitor_DefaultConfig tests default configuration
func TestHealthMonitor_DefaultConfig(t *testing.T) {
config := DefaultHealthMonitorConfig()
assert.Equal(t, 5*time.Second, config.CheckInterval)
assert.Equal(t, 3*time.Second, config.Timeout)
assert.Equal(t, 3, config.UnhealthyThreshold)
assert.Nil(t, config.OnHealthChange)
}
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1, // Very small pool
ConnectTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond, // Short timeout
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Get the only connection, blocking health checks
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
// Wait for health check attempts
time.Sleep(350 * time.Millisecond)
// Health monitor might mark as unhealthy due to timeouts
stats := hm.GetStats()
t.Logf("Stats with blocked pool: %+v", stats)
// Return connection
pool.Put(conn)
// Wait for recovery
time.Sleep(300 * time.Millisecond)
// Should recover
assert.True(t, hm.IsHealthy())
}
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
func TestConnectionPool_WithHealthChecks(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn))
// Use connection
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse and validate
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
}
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 3,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get and return a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
initialTotal := pool.totalConns.Load()
// Close the connection manually to make it stale
conn.Close()
// Get another connection - should detect stale and create new
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn2))
pool.Put(conn2)
// Total connections might be same or less (stale removed)
finalTotal := pool.totalConns.Load()
assert.LessOrEqual(t, finalTotal, initialTotal+1)
}
+338
View File
@@ -0,0 +1,338 @@
package backends
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// ConnectionPool manages a pool of Redis connections
// Pure-Go implementation compatible with Yaegi
type ConnectionPool struct {
config *PoolConfig
connections chan *RedisConn
mu sync.Mutex
closed atomic.Bool
// Metrics
activeConns atomic.Int32
totalConns atomic.Int32
gets atomic.Int64
puts atomic.Int64
timeouts atomic.Int64
}
// PoolConfig holds connection pool configuration
type PoolConfig struct {
Address string
Password string
DB int
MaxConnections int
ConnectTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
}
// NewConnectionPool creates a new connection pool
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
if config == nil {
return nil, errors.New("config is required")
}
if config.MaxConnections <= 0 {
config.MaxConnections = 10
}
if config.ConnectTimeout == 0 {
config.ConnectTimeout = 5 * time.Second
}
pool := &ConnectionPool{
config: config,
connections: make(chan *RedisConn, config.MaxConnections),
}
return pool, nil
}
// Get retrieves a connection from the pool or creates a new one
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
if p.closed.Load() {
return nil, ErrBackendClosed
}
p.gets.Add(1)
// Try to get a connection with validation
maxAttempts := 3
for attempt := 0; attempt < maxAttempts; attempt++ {
var conn *RedisConn
var err error
select {
case conn = <-p.connections:
// Reuse existing connection - validate if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
// Connection is stale, close it and try again
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
return nil, err
}
// Wait before retry with exponential backoff
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
continue
}
p.activeConns.Add(1)
p.totalConns.Add(1)
return conn, nil
}
// Pool exhausted, wait for a connection with timeout
select {
case conn = <-p.connections:
// Validate connection if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
p.timeouts.Add(1)
return nil, ctx.Err()
case <-time.After(p.config.ConnectTimeout):
p.timeouts.Add(1)
return nil, ErrPoolExhausted
}
}
}
return nil, errors.New("failed to get healthy connection after retries")
}
// Put returns a connection to the pool
func (p *ConnectionPool) Put(conn *RedisConn) {
if conn == nil {
return
}
p.puts.Add(1)
p.activeConns.Add(-1)
if p.closed.Load() || conn.closed.Load() {
_ = conn.Close()
p.totalConns.Add(-1)
return
}
// Return to pool (non-blocking)
select {
case p.connections <- conn:
// Successfully returned to pool
default:
// Pool full, close connection
_ = conn.Close()
p.totalConns.Add(-1)
}
}
// Close closes all connections in the pool
func (p *ConnectionPool) Close() error {
if p.closed.Swap(true) {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
close(p.connections)
// Close all pooled connections
for conn := range p.connections {
_ = conn.Close()
}
return nil
}
// Stats returns pool statistics
func (p *ConnectionPool) Stats() map[string]interface{} {
return map[string]interface{}{
"active_connections": p.activeConns.Load(),
"total_connections": p.totalConns.Load(),
"max_connections": p.config.MaxConnections,
"gets": p.gets.Load(),
"puts": p.puts.Load(),
"timeouts": p.timeouts.Load(),
}
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
redisConn := &RedisConn{
conn: conn,
readTimeout: p.config.ReadTimeout,
writeTimeout: p.config.WriteTimeout,
}
// Authenticate if password is provided
if p.config.Password != "" {
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
_ = redisConn.Close()
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
// Select database
if p.config.DB != 0 {
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
_ = redisConn.Close()
return nil, fmt.Errorf("failed to select database: %w", err)
}
}
return redisConn, nil
}
// RedisConn represents a single Redis connection
type RedisConn struct {
conn net.Conn
readTimeout time.Duration
writeTimeout time.Duration
closed atomic.Bool
mu sync.Mutex
}
// Do executes a Redis command and returns the response
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
if c.closed.Load() {
return nil, ErrBackendClosed
}
c.mu.Lock()
defer c.mu.Unlock()
// Validate argument count to prevent integer overflow in slice operations
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
const maxSafeArgs = (1 << 20) - 1
if len(args) > maxSafeArgs {
return nil, errors.New("too many arguments: exceeds maximum safe count")
}
// Build command arguments
// Validate total argument size to prevent memory exhaustion
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
totalBytes := len(command)
for _, s := range args {
// Protect against possible overflow
if len(s) > maxTotalArgBytes-totalBytes {
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
}
totalBytes += len(s)
if totalBytes > maxTotalArgBytes {
return nil, errors.New("total argument size exceeds maximum allowed")
}
}
// Build command slice: prepend command to args
// Using append avoids arithmetic on potentially large len(args)
cmdArgs := append([]string{command}, args...)
// Set write timeout
if c.writeTimeout > 0 {
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// Write command (using pooled writer for memory efficiency)
writer := NewRESPWriter(c.conn)
err := writer.WriteCommand(cmdArgs...)
writer.Release() // Return to pool immediately after use
if err != nil {
c.closed.Store(true)
return nil, err
}
// Set read timeout
if c.readTimeout > 0 {
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
// Read response (using pooled reader for memory efficiency)
reader := NewRESPReader(c.conn)
resp, err := reader.ReadResponse()
reader.Release() // Return to pool immediately after use
if err != nil {
if !errors.Is(err, ErrNilResponse) {
c.closed.Store(true)
}
return nil, err
}
return resp, nil
}
// Close closes the connection
func (c *RedisConn) Close() error {
if c.closed.Swap(true) {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
return c.conn.Close()
}
return nil
}
// isConnectionHealthy validates a connection is still working
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
if conn == nil || conn.closed.Load() {
return false
}
// Set a read deadline for the ping
if conn.conn != nil {
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
}
_, err := conn.Do("PING")
return err == nil
}
+620
View File
@@ -0,0 +1,620 @@
package backends
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnectionPool_BasicOperations tests basic pool operations
func TestConnectionPool_BasicOperations(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
t.Run("GetAndPutConnection", func(t *testing.T) {
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Verify connection works
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse same connection
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
})
t.Run("Stats", func(t *testing.T) {
stats := pool.Stats()
require.NotNil(t, stats)
assert.Contains(t, stats, "active_connections")
assert.Contains(t, stats, "total_connections")
assert.Contains(t, stats, "max_connections")
assert.Equal(t, 5, stats["max_connections"])
})
}
// TestConnectionPool_MaxConnections tests pool size limits
func TestConnectionPool_MaxConnections(t *testing.T) {
mr := NewMiniredisServer(t)
maxConns := 3
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: maxConns,
ConnectTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get max connections
conns := make([]*RedisConn, maxConns)
for i := 0; i < maxConns; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
conns[i] = conn
}
// Verify stats
stats := pool.Stats()
assert.Equal(t, int32(maxConns), stats["total_connections"])
assert.Equal(t, int32(maxConns), stats["active_connections"])
// Try to get one more - should block/timeout
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := pool.Get(ctx2)
require.Error(t, err)
require.Nil(t, conn)
// Return one connection
pool.Put(conns[0])
// Now we should be able to get a connection
conn, err = pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
// Cleanup
pool.Put(conn)
for i := 1; i < maxConns; i++ {
pool.Put(conns[i])
}
}
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
numGoroutines := 50
numOperations := 20
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*numOperations)
// Spawn goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
conn, err := pool.Get(ctx)
if err != nil {
errors <- err
continue
}
// Do some work
_, err = conn.Do("PING")
if err != nil {
errors <- err
}
// Return to pool
pool.Put(conn)
// Small delay
time.Sleep(time.Millisecond)
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
errorCount := 0
for err := range errors {
t.Logf("Error: %v", err)
errorCount++
}
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
// Verify stats
stats := pool.Stats()
t.Logf("Final stats: %+v", stats)
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
assert.Equal(t, int32(0), stats["active_connections"])
}
// TestConnectionPool_ContextCancellation tests context cancellation
func TestConnectionPool_ContextCancellation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get the only connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
conn2, err := pool.Get(ctx)
require.Error(t, err)
require.Nil(t, conn2)
assert.Contains(t, err.Error(), "context canceled")
// Cleanup
pool.Put(conn)
}
// TestConnectionPool_Authentication tests auth support
func TestConnectionPool_Authentication(t *testing.T) {
mr := NewMiniredisServer(t)
// Set password on miniredis
mr.server.RequireAuth("secret-password")
t.Run("CorrectPassword", func(t *testing.T) {
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "secret-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
})
t.Run("WrongPassword", func(t *testing.T) {
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "wrong-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
_, err := NewConnectionPool(config)
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
})
}
// TestConnectionPool_DatabaseSelection tests DB selection
func TestConnectionPool_DatabaseSelection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
DB: 5,
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Connection should be on DB 5
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestConnectionPool_ClosedConnection tests handling closed connections
func TestConnectionPool_ClosedConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Close it manually
conn.Close()
// Try to use it
_, err = conn.Do("PING")
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Return to pool (should be discarded)
pool.Put(conn)
// Get new connection - should create a new one
conn2, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn2)
resp, err := conn2.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn2)
}
// TestConnectionPool_Close tests pool closure
func TestConnectionPool_Close(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
// Get some connections
conns := make([]*RedisConn, 3)
for i := 0; i < 3; i++ {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
conns[i] = conn
}
// Return them
for _, conn := range conns {
pool.Put(conn)
}
// Close pool
err = pool.Close()
require.NoError(t, err)
// Try to get connection from closed pool
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Close again should be no-op
err = pool.Close()
require.NoError(t, err)
}
// TestConnectionPool_Timeouts tests various timeout scenarios
func TestConnectionPool_Timeouts(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
WriteTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Normal operation should work
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestRedisConn_DoCommand tests the Do method
func TestRedisConn_DoCommand(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SET and GET", func(t *testing.T) {
// SET
resp, err := conn.Do("SET", "testkey", "testvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// GET
resp, err = conn.Do("GET", "testkey")
require.NoError(t, err)
assert.Equal(t, "testvalue", resp)
})
t.Run("DEL", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "delkey", "delvalue")
require.NoError(t, err)
// DEL
resp, err := conn.Do("DEL", "delkey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
})
t.Run("EXISTS", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "existskey", "value")
require.NoError(t, err)
// EXISTS - key exists
resp, err := conn.Do("EXISTS", "existskey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
// EXISTS - key doesn't exist
resp, err = conn.Do("EXISTS", "nonexistent")
require.NoError(t, err)
count, err = RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
})
t.Run("TTL commands", func(t *testing.T) {
// SETEX
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// TTL
resp, err = conn.Do("TTL", "ttlkey")
require.NoError(t, err)
ttl, err := RESPInt(resp)
require.NoError(t, err)
assert.Greater(t, ttl, int64(0))
assert.LessOrEqual(t, ttl, int64(60))
})
}
// TestPoolConfig_Defaults tests default configuration values
func TestPoolConfig_Defaults(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
// Leave other fields at zero values
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Should use defaults
assert.Equal(t, 10, pool.config.MaxConnections)
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
// Verify it works
conn, err := pool.Get(context.Background())
require.NoError(t, err)
pool.Put(conn)
}
// TestConnectionPool_NilConnection tests handling nil connections
func TestConnectionPool_NilConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Putting nil should be safe
pool.Put(nil)
// Pool should still work
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
pool.Put(conn)
}
// TestConnectionPool_StatsTracking tests metrics tracking
func TestConnectionPool_StatsTracking(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Initial stats
stats := pool.Stats()
initialGets := stats["gets"].(int64)
initialPuts := stats["puts"].(int64)
// Perform operations
numOps := 10
for i := 0; i < numOps; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
}
// Check updated stats
stats = pool.Stats()
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
assert.Equal(t, int32(0), stats["active_connections"].(int32))
}
// TestRedisConn_TooManyArguments tests protection against allocation overflow
func TestRedisConn_TooManyArguments(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("AcceptableArgumentCount", func(t *testing.T) {
// Should work with reasonable number of args
args := make([]string, 100)
for i := range args {
args[i] = "value"
}
_, err := conn.Do("MSET", args...)
// May fail due to Redis constraints, but shouldn't panic or error on overflow
// Just verify it doesn't trigger our overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
t.Run("RejectExcessiveArguments", func(t *testing.T) {
// Create an absurdly large number of arguments that would cause overflow
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
args := make([]string, 1<<20) // 1,048,576 args
for i := range args {
args[i] = "x"
}
_, err := conn.Do("MSET", args...)
require.Error(t, err)
assert.Contains(t, err.Error(), "too many arguments")
})
t.Run("BoundaryCase", func(t *testing.T) {
// Test exactly at the boundary (maxSafeArgs)
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
for i := range args {
args[i] = "x"
}
_, err := conn.Do("ECHO", args...)
// Should not error due to overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
}
+545
View File
@@ -0,0 +1,545 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisBackend_BasicOperations tests basic Redis operations
func TestRedisBackend_BasicOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "redis-test-key"
value := []byte("redis-test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "redis-delete-key"
value := []byte("redis-delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
key := "redis-exists-key"
value := []byte("redis-exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
}
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
func TestRedisBackend_KeyPrefixing(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "test:prefix:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "my-key"
value := []byte("my-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check that key is stored with prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, "test:prefix:my-key", keys[0])
// Get should work without prefix
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
}
// TestRedisBackend_TTLExpiration tests TTL handling
func TestRedisBackend_TTLExpiration(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward time in miniredis
mr.FastForward(150 * time.Millisecond)
// Should be expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLRemaining", func(t *testing.T) {
key := "ttl-remaining-key"
value := []byte("ttl-remaining-value")
ttl := 10 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Get immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward 2 seconds
mr.FastForward(2 * time.Second)
// Check TTL is less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1)
})
}
// TestRedisBackend_Clear tests clearing all keys
func TestRedisBackend_Clear(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "clear-test:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add multiple keys
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Verify keys exist
keys := mr.CheckKeys()
assert.Len(t, keys, 10)
// Clear all
err = backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
keys = mr.CheckKeys()
assert.Len(t, keys, 0)
}
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
func TestRedisBackend_ConnectionFailure(t *testing.T) {
t.Parallel()
// Try to connect to non-existent Redis
config := DefaultRedisConfig("localhost:9999")
_, err := NewRedisBackend(config)
assert.Error(t, err, "Should fail to connect to non-existent Redis")
}
// TestRedisBackend_RedisErrors tests handling of Redis errors
func TestRedisBackend_RedisErrors(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Simulate Redis error
mr.SetError("simulated error")
// Operations should fail
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
assert.Error(t, err)
// Clear error
mr.ClearError()
// Operations should work again
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
assert.NoError(t, err)
}
// TestRedisBackend_ConcurrentAccess tests thread safety
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0))
}
// TestRedisBackend_Stats tests statistics tracking
func TestRedisBackend_Stats(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add and access items
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Get(ctx, "key1") // Hit
backend.Get(ctx, "non-existent") // Miss
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestRedisBackend_Ping tests health check
func TestRedisBackend_Ping(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestRedisBackend_Close tests proper cleanup
func TestRedisBackend_Close(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Double close should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestRedisBackend_UpdateExisting tests updating existing keys
func TestRedisBackend_UpdateExisting(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute)
}
// TestRedisBackend_LargeValues tests handling of large values
func TestRedisBackend_LargeValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestRedisBackend_EmptyValues tests handling of empty values
func TestRedisBackend_EmptyValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestRedisBackend_PipelineOperations tests batch operations
func TestRedisBackend_PipelineOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetMany", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
key := fmt.Sprintf("batch-key-%d", i)
value := []byte(fmt.Sprintf("batch-value-%d", i))
items[key] = value
}
err := backend.SetMany(ctx, items, 1*time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, retrieved)
}
})
t.Run("GetMany", func(t *testing.T) {
// Set test data
testData := GenerateTestData(5)
for key, value := range testData {
backend.Set(ctx, key, value, 1*time.Minute)
}
// Get all keys
keys := make([]string, 0, len(testData))
for key := range testData {
keys = append(keys, key)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, len(testData))
for key, expectedValue := range testData {
retrievedValue, exists := results[key]
assert.True(t, exists)
assert.Equal(t, expectedValue, retrievedValue)
}
})
t.Run("GetManyWithNonExistent", func(t *testing.T) {
keys := []string{"exists-1", "non-existent", "exists-2"}
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-1"), results["exists-1"])
assert.Equal(t, []byte("value-2"), results["exists-2"])
_, exists := results["non-existent"]
assert.False(t, exists)
})
}
// TestRedisBackend_NoPrefix tests operation without prefix
func TestRedisBackend_NoPrefix(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "" // No prefix
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "no-prefix-key"
value := []byte("no-prefix-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check key is stored without prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, key, keys[0])
}
+251
View File
@@ -0,0 +1,251 @@
package backends
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
}
// Release returns the writer to the pool for reuse
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
}
// WriteCommand writes a Redis command in RESP array format
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
func (w *RESPWriter) WriteCommand(args ...string) error {
// Write array header
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
return err
}
// Write each argument as bulk string
for _, arg := range args {
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
return err
}
}
return nil
}
// RESPReader reads RESP protocol messages
type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
}
// Release returns the reader to the pool for reuse
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
}
// ReadResponse reads a RESP response and returns the parsed value
func (r *RESPReader) ReadResponse() (interface{}, error) {
typeByte, err := r.r.ReadByte()
if err != nil {
return nil, err
}
switch typeByte {
case '+': // Simple string
return r.readSimpleString()
case '-': // Error
return nil, r.readError()
case ':': // Integer
return r.readInteger()
case '$': // Bulk string
return r.readBulkString()
case '*': // Array
return r.readArray()
default:
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
}
}
// readSimpleString reads a simple string (+OK\r\n)
func (r *RESPReader) readSimpleString() (string, error) {
line, err := r.readLine()
if err != nil {
return "", err
}
return line, nil
}
// readError reads an error message (-Error message\r\n)
func (r *RESPReader) readError() error {
line, err := r.readLine()
if err != nil {
return err
}
return errors.New(line)
}
// readInteger reads an integer (:1000\r\n)
func (r *RESPReader) readInteger() (int64, error) {
line, err := r.readLine()
if err != nil {
return 0, err
}
return strconv.ParseInt(line, 10, 64)
}
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
func (r *RESPReader) readBulkString() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
}
// -1 indicates nil bulk string
if length == -1 {
return nil, ErrNilResponse
}
// Read exactly 'length' bytes plus \r\n
buf := make([]byte, length+2)
if _, err := io.ReadFull(r.r, buf); err != nil {
return nil, err
}
// Verify \r\n terminator
if buf[length] != '\r' || buf[length+1] != '\n' {
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
}
return string(buf[:length]), nil
}
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
func (r *RESPReader) readArray() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
}
// -1 indicates nil array
if length == -1 {
return nil, ErrNilResponse
}
// Read each element
result := make([]interface{}, length)
for i := 0; i < length; i++ {
elem, err := r.ReadResponse()
if err != nil {
return nil, err
}
result[i] = elem
}
return result, nil
}
// readLine reads a line terminated by \r\n
func (r *RESPReader) readLine() (string, error) {
line, err := r.r.ReadString('\n')
if err != nil {
return "", err
}
// Remove \r\n
line = strings.TrimSuffix(line, "\r\n")
if !strings.HasSuffix(line+"\r\n", "\r\n") {
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
}
return line, nil
}
// RESPString extracts a string from RESP response
func RESPString(resp interface{}) (string, error) {
if resp == nil {
return "", ErrNilResponse
}
switch v := resp.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
default:
return "", fmt.Errorf("expected string, got %T", resp)
}
}
// RESPInt extracts an integer from RESP response
func RESPInt(resp interface{}) (int64, error) {
if resp == nil {
return 0, ErrNilResponse
}
switch v := resp.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
default:
return 0, fmt.Errorf("expected integer, got %T", resp)
}
}
+495
View File
@@ -0,0 +1,495 @@
package backends
import (
"bytes"
"errors"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRESPWriter_WriteCommand tests RESP command writing
func TestRESPWriter_WriteCommand(t *testing.T) {
tests := []struct {
name string
expected string
args []string
}{
{
name: "Simple command",
args: []string{"PING"},
expected: "*1\r\n$4\r\nPING\r\n",
},
{
name: "SET command",
args: []string{"SET", "key", "value"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
name: "SETEX command",
args: []string{"SETEX", "mykey", "60", "myvalue"},
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
},
{
name: "DEL with multiple keys",
args: []string{"DEL", "key1", "key2", "key3"},
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
},
{
name: "Command with empty string",
args: []string{"SET", "key", ""},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
name: "Command with special characters",
args: []string{"SET", "key", "val\r\nue"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
writer := NewRESPWriter(buf)
err := writer.WriteCommand(tt.args...)
require.NoError(t, err)
assert.Equal(t, tt.expected, buf.String())
})
}
}
// TestRESPReader_ReadSimpleString tests reading simple strings
func TestRESPReader_ReadSimpleString(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "OK response",
input: "+OK\r\n",
expected: "OK",
wantErr: false,
},
{
name: "PONG response",
input: "+PONG\r\n",
expected: "PONG",
wantErr: false,
},
{
name: "Empty string",
input: "+\r\n",
expected: "",
wantErr: false,
},
{
name: "String with spaces",
input: "+Hello World\r\n",
expected: "Hello World",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadError tests reading error messages
func TestRESPReader_ReadError(t *testing.T) {
tests := []struct {
name string
input string
expectedError string
}{
{
name: "ERR error",
input: "-ERR unknown command\r\n",
expectedError: "ERR unknown command",
},
{
name: "WRONGTYPE error",
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
},
{
name: "Simple error",
input: "-Error\r\n",
expectedError: "Error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.Equal(t, tt.expectedError, err.Error())
})
}
}
// TestRESPReader_ReadInteger tests reading integers
func TestRESPReader_ReadInteger(t *testing.T) {
tests := []struct {
name string
input string
expected int64
wantErr bool
}{
{
name: "Zero",
input: ":0\r\n",
expected: 0,
wantErr: false,
},
{
name: "Positive integer",
input: ":1000\r\n",
expected: 1000,
wantErr: false,
},
{
name: "Negative integer",
input: ":-1\r\n",
expected: -1,
wantErr: false,
},
{
name: "Large integer",
input: ":9223372036854775807\r\n",
expected: 9223372036854775807,
wantErr: false,
},
{
name: "Invalid integer",
input: ":abc\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadBulkString tests reading bulk strings
func TestRESPReader_ReadBulkString(t *testing.T) {
tests := []struct {
expected interface{}
name string
input string
wantErr bool
isNil bool
}{
{
name: "Simple bulk string",
input: "$6\r\nfoobar\r\n",
expected: "foobar",
wantErr: false,
},
{
name: "Empty bulk string",
input: "$0\r\n\r\n",
expected: "",
wantErr: false,
},
{
name: "Nil bulk string",
input: "$-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Binary safe bulk string",
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
expected: "\x00\x01\x02\x03\x04",
wantErr: false,
},
{
name: "Invalid length",
input: "$abc\r\ntest\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadArray tests reading arrays
func TestRESPReader_ReadArray(t *testing.T) {
tests := []struct {
name string
input string
expected []interface{}
wantErr bool
isNil bool
}{
{
name: "Empty array",
input: "*0\r\n",
expected: []interface{}{},
wantErr: false,
},
{
name: "Array of bulk strings",
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
expected: []interface{}{
"foo",
"bar",
},
wantErr: false,
},
{
name: "Array of integers",
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
},
wantErr: false,
},
{
name: "Mixed array",
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
int64(4),
"foobar",
},
wantErr: false,
},
{
name: "Nil array",
input: "*-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Nested arrays",
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
expected: []interface{}{
[]interface{}{"foo", "bar"},
[]interface{}{"baz"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_InvalidInput tests error handling for invalid input
func TestRESPReader_InvalidInput(t *testing.T) {
tests := []struct {
name string
input string
}{
{
name: "Unknown type byte",
input: "?invalid\r\n",
},
{
name: "Incomplete response",
input: "+OK",
},
{
name: "Missing CRLF in bulk string",
input: "$5\r\nhello",
},
{
name: "Truncated array",
input: "*3\r\n:1\r\n:2\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
})
}
}
// TestRESPReader_EOF tests handling of EOF
func TestRESPReader_EOF(t *testing.T) {
reader := NewRESPReader(strings.NewReader(""))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.True(t, errors.Is(err, io.EOF))
}
// TestRESPHelpers tests helper functions
func TestRESPHelpers(t *testing.T) {
t.Run("RESPString", func(t *testing.T) {
// Valid string
result, err := RESPString("hello")
require.NoError(t, err)
assert.Equal(t, "hello", result)
// Byte slice
result, err = RESPString([]byte("world"))
require.NoError(t, err)
assert.Equal(t, "world", result)
// Nil
_, err = RESPString(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPString(123)
require.Error(t, err)
})
t.Run("RESPInt", func(t *testing.T) {
// Valid int64
result, err := RESPInt(int64(42))
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Valid int
result, err = RESPInt(42)
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Nil
_, err = RESPInt(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPInt("string")
require.Error(t, err)
})
}
// TestRESPRoundTrip tests full round-trip encoding/decoding
func TestRESPRoundTrip(t *testing.T) {
tests := []struct {
expected interface{}
name string
response string
command []string
}{
{
name: "PING command",
command: []string{"PING"},
response: "+PONG\r\n",
expected: "PONG",
},
{
name: "GET command with result",
command: []string{"GET", "mykey"},
response: "$7\r\nmyvalue\r\n",
expected: "myvalue",
},
{
name: "GET command with nil",
command: []string{"GET", "nonexistent"},
response: "$-1\r\n",
expected: nil,
},
{
name: "DEL command",
command: []string{"DEL", "key1", "key2"},
response: ":2\r\n",
expected: int64(2),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Write command
writeBuf := &bytes.Buffer{}
writer := NewRESPWriter(writeBuf)
err := writer.WriteCommand(tt.command...)
require.NoError(t, err)
// Read response
reader := NewRESPReader(strings.NewReader(tt.response))
result, err := reader.ReadResponse()
if tt.expected == nil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
+198
View File
@@ -0,0 +1,198 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
// TestLogger implements a simple logger for tests
type TestLogger struct {
t *testing.T
}
func NewTestLogger(t *testing.T) *TestLogger {
return &TestLogger{t: t}
}
func (l *TestLogger) Debug(format string, args ...interface{}) {
l.t.Logf("[DEBUG] "+format, args...)
}
func (l *TestLogger) Info(format string, args ...interface{}) {
l.t.Logf("[INFO] "+format, args...)
}
func (l *TestLogger) Error(format string, args ...interface{}) {
l.t.Logf("[ERROR] "+format, args...)
}
func (l *TestLogger) Debugf(format string, args ...interface{}) {
l.Debug(format, args...)
}
func (l *TestLogger) Infof(format string, args ...interface{}) {
l.Info(format, args...)
}
func (l *TestLogger) Errorf(format string, args ...interface{}) {
l.Error(format, args...)
}
func (l *TestLogger) Warnf(format string, args ...interface{}) {
l.t.Logf("[WARN] "+format, args...)
}
// MiniredisServer manages a miniredis instance for testing
type MiniredisServer struct {
server *miniredis.Miniredis
client *redis.Client
}
// NewMiniredisServer creates a new miniredis server for testing
func NewMiniredisServer(t *testing.T) *MiniredisServer {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err, "failed to start miniredis")
client := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
// Verify connection
ctx := context.Background()
err = client.Ping(ctx).Err()
require.NoError(t, err, "failed to ping miniredis")
t.Cleanup(func() {
client.Close()
mr.Close()
})
return &MiniredisServer{
server: mr,
client: client,
}
}
// GetAddr returns the address of the miniredis server
func (m *MiniredisServer) GetAddr() string {
return m.server.Addr()
}
// GetClient returns the Redis client
func (m *MiniredisServer) GetClient() *redis.Client {
return m.client
}
// FastForward advances the miniredis server's time
func (m *MiniredisServer) FastForward(d time.Duration) {
m.server.FastForward(d)
}
// FlushAll removes all keys from the database
func (m *MiniredisServer) FlushAll() {
m.server.FlushAll()
}
// SetError simulates a Redis error
func (m *MiniredisServer) SetError(err string) {
m.server.SetError(err)
}
// ClearError clears any simulated errors
func (m *MiniredisServer) ClearError() {
m.server.SetError("")
}
// CheckKeys verifies that specific keys exist in Redis
func (m *MiniredisServer) CheckKeys() []string {
return m.server.Keys()
}
// Close closes the miniredis server
func (m *MiniredisServer) Close() {
m.server.Close()
}
// Restart restarts the miniredis server
func (m *MiniredisServer) Restart() {
m.server.Restart()
}
// TestConfig provides default test configuration
type TestConfig struct {
MaxSize int
DefaultTTL time.Duration
CleanupInterval time.Duration
EnableMetrics bool
}
// DefaultTestConfig returns a standard test configuration
func DefaultTestConfig() *TestConfig {
return &TestConfig{
MaxSize: 100,
DefaultTTL: 5 * time.Minute,
CleanupInterval: 1 * time.Second,
EnableMetrics: true,
}
}
// GenerateTestData creates test cache data
func GenerateTestData(count int) map[string][]byte {
data := make(map[string][]byte, count)
for i := 0; i < count; i++ {
key := fmt.Sprintf("test-key-%d", i)
value := []byte(fmt.Sprintf("test-value-%d", i))
data[key] = value
}
return data
}
// GenerateLargeValue creates a large test value
func GenerateLargeValue(sizeBytes int) []byte {
return make([]byte, sizeBytes)
}
// AssertCacheStats is a helper to verify cache statistics
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
t.Helper()
hits, ok := stats["hits"].(int64)
require.True(t, ok, "hits should be int64")
require.Equal(t, expectedHits, hits, "unexpected hit count")
misses, ok := stats["misses"].(int64)
require.True(t, ok, "misses should be int64")
require.Equal(t, expectedMisses, misses, "unexpected miss count")
}
// WaitForCondition waits for a condition to be true or times out
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(checkInterval)
}
t.Fatal("timeout waiting for condition")
}
// AssertEventuallyExpires verifies that a key eventually expires
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
t.Helper()
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
_, _, exists, err := backend.Get(ctx, key)
return err == nil && !exists
})
}
+29 -41
View File
@@ -33,21 +33,19 @@ type Logger interface {
// Config provides configuration for the cache
type Config struct {
Logger Logger
JWKConfig *JWKConfig
MetadataConfig *MetadataConfig
TokenConfig *TokenConfig
Type Type
MaxSize int
MaxMemoryBytes int64
DefaultTTL time.Duration
CleanupInterval time.Duration
EnableCompression bool
MaxMemoryBytes int64
MaxSize int
EnableMetrics bool
EnableAutoCleanup bool
EnableMemoryLimit bool
Logger Logger
// Type-specific configurations
TokenConfig *TokenConfig
MetadataConfig *MetadataConfig
JWKConfig *JWKConfig
EnableCompression bool
}
// TokenConfig provides token-specific cache configuration
@@ -59,11 +57,11 @@ type TokenConfig struct {
// MetadataConfig provides metadata-specific cache configuration
type MetadataConfig struct {
SecurityCriticalFields []string
GracePeriod time.Duration
ExtendedGracePeriod time.Duration
MaxGracePeriod time.Duration
SecurityCriticalMaxGracePeriod time.Duration
SecurityCriticalFields []string
}
// JWKConfig provides JWK-specific cache configuration
@@ -75,45 +73,35 @@ type JWKConfig struct {
// Item represents a single cache entry
type Item struct {
Key string
Value interface{}
Size int64
ExpiresAt time.Time
LastAccessed time.Time
AccessCount int64
Value interface{}
Metadata map[string]interface{}
element *list.Element
Key string
CacheType Type
// Type-specific metadata
Metadata map[string]interface{}
// LRU list element reference
element *list.Element
Size int64
AccessCount int64
}
// Cache provides a single, unified cache implementation
type Cache struct {
mu sync.RWMutex
items map[string]*Item
lruList *list.List
config Config
logger Logger
// Memory management
config Config
ctx context.Context
logger Logger
cancel context.CancelFunc
lruList *list.List
items map[string]*Item
stopCleanup chan bool
wg sync.WaitGroup
currentSize int64
currentMemory int64
// Metrics
hits int64
misses int64
evictions int64
sets int64
// Lifecycle management
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
stopCleanup chan bool
closed int32
hits int64
misses int64
evictions int64
sets int64
mu sync.RWMutex
closed int32
}
// DefaultConfig returns a default cache configuration
@@ -355,7 +343,7 @@ func (c *Cache) removeItem(key string, item *Item) {
func (c *Cache) evictLRU() {
if elem := c.lruList.Back(); elem != nil {
item := elem.Value.(*Item)
item, _ := elem.Value.(*Item) // Safe to ignore: type assertion from known type
c.removeItem(item.Key, item)
atomic.AddInt64(&c.evictions, 1)
c.logger.Debugf("Cache: Evicted LRU item key=%s", item.Key)
+107 -21
View File
@@ -1750,19 +1750,19 @@ func TestAdvancedEdgeCases(t *testing.T) {
// Test with various data types
testCases := []struct {
key string
value interface{}
key string
}{
{"string", "test string"},
{"int", 42},
{"float", 3.14159},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"one": 1, "two": 2}},
{"nil", nil},
{"empty-string", ""},
{"empty-slice", []string{}},
{"empty-map", map[string]interface{}{}},
{key: "string", value: "test string"},
{key: "int", value: 42},
{key: "float", value: 3.14159},
{key: "bool", value: true},
{key: "slice", value: []string{"a", "b", "c"}},
{key: "map", value: map[string]int{"one": 1, "two": 2}},
{key: "nil", value: nil},
{key: "empty-string", value: ""},
{key: "empty-slice", value: []string{}},
{key: "empty-map", value: map[string]interface{}{}},
}
for _, tc := range testCases {
@@ -1880,19 +1880,20 @@ func TestConcurrentManagerOperations(t *testing.T) {
// TestTTLExpirationAndCleanup tests TTL expiration and cleanup routines comprehensively
func TestTTLExpirationAndCleanup(t *testing.T) {
config := DefaultConfig()
config.CleanupInterval = 10 * time.Millisecond
config.CleanupInterval = 50 * time.Millisecond
config.EnableAutoCleanup = true
cache := New(config)
defer cache.Close()
// Test various TTL scenarios
// Note: Timing increased 5x to account for race detector overhead
testCases := []struct {
key string
ttl time.Duration
}{
{"very-short", 5 * time.Millisecond},
{"short", 25 * time.Millisecond},
{"medium", 100 * time.Millisecond},
{"very-short", 25 * time.Millisecond},
{"short", 125 * time.Millisecond},
{"medium", 500 * time.Millisecond},
{"long", 1 * time.Hour},
}
@@ -1908,13 +1909,13 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
}
// Wait for very short items to expire
time.Sleep(15 * time.Millisecond)
time.Sleep(75 * time.Millisecond)
if _, exists := cache.Get("very-short"); exists {
t.Error("Very short item should be expired")
}
// Wait for short items to expire
time.Sleep(30 * time.Millisecond)
time.Sleep(150 * time.Millisecond)
if _, exists := cache.Get("short"); exists {
t.Error("Short item should be expired")
}
@@ -1930,16 +1931,16 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
}
// Test manual cleanup
cache.Set("manual-cleanup", "value", 1*time.Millisecond)
time.Sleep(5 * time.Millisecond)
cache.Set("manual-cleanup", "value", 5*time.Millisecond)
time.Sleep(25 * time.Millisecond)
cache.Cleanup()
// Add many expired items to test bulk cleanup
for i := 0; i < 100; i++ {
key := fmt.Sprintf("bulk-%d", i)
cache.Set(key, fmt.Sprintf("value-%d", i), 1*time.Millisecond)
cache.Set(key, fmt.Sprintf("value-%d", i), 5*time.Millisecond)
}
time.Sleep(5 * time.Millisecond)
time.Sleep(25 * time.Millisecond)
sizeBefore := cache.Size()
cache.Cleanup()
@@ -2038,3 +2039,88 @@ func TestCacheStatisticsAndMetrics(t *testing.T) {
t.Error("Memory usage should increase after adding large item")
}
}
// ============================================================================
// noOpLogger Tests
// ============================================================================
// TestNoOpLogger_AllMethods tests all noOpLogger methods to ensure they don't panic
func TestNoOpLogger_AllMethods(t *testing.T) {
logger := &noOpLogger{}
// Test simple message methods
logger.Debug("test debug message")
logger.Info("test info message")
logger.Error("test error message")
logger.Warn("test warn message")
logger.Fatal("test fatal message")
// Test formatted message methods
logger.Debugf("test debug: %s", "value")
logger.Infof("test info: %s", "value")
logger.Errorf("test error: %s", "value")
logger.Warnf("test warn: %s", "value")
logger.Fatalf("test fatal: %s", "value")
// If we reach here, all methods executed without panicking
// This is expected behavior for a no-op logger
}
// TestNoOpLogger_WithField verifies WithField returns the same logger
func TestNoOpLogger_WithField(t *testing.T) {
logger := &noOpLogger{}
result := logger.WithField("key", "value")
if result != logger {
t.Error("WithField should return the same logger instance")
}
// Verify the returned logger works
result.Info("test message after WithField")
}
// TestNoOpLogger_WithFields verifies WithFields returns the same logger
func TestNoOpLogger_WithFields(t *testing.T) {
logger := &noOpLogger{}
fields := map[string]interface{}{
"key1": "value1",
"key2": 123,
"key3": true,
}
result := logger.WithFields(fields)
if result != logger {
t.Error("WithFields should return the same logger instance")
}
// Verify the returned logger works
result.Info("test message after WithFields")
}
// TestNoOpLogger_Chaining verifies method chaining works
func TestNoOpLogger_Chaining(t *testing.T) {
logger := &noOpLogger{}
// Use WithField and verify it returns a usable logger
result := logger.WithField("key1", "value1")
// Verify the result can be used for logging (Logger interface methods)
result.Info("info after WithField")
result.Infof("infof after WithField: %s", "test")
result.Debug("debug after WithField")
result.Debugf("debugf after WithField: %d", 123)
result.Error("error after WithField")
result.Errorf("errorf after WithField: %v", true)
// Use WithFields and verify it returns a usable logger
result2 := logger.WithFields(map[string]interface{}{
"key2": "value2",
"key3": 123,
})
// Verify the result can be used for logging
result2.Infof("message after WithFields: %s", "test")
}
+2
View File
@@ -1,3 +1,5 @@
// Package cache provides high-performance caching implementations for OIDC tokens, metadata, and JWKs.
// It includes compatibility wrappers for backward compatibility with existing cache interfaces.
package cache
import (
+2 -7
View File
@@ -7,22 +7,17 @@ import (
// Manager manages multiple cache instances with singleton pattern
type Manager struct {
mu sync.RWMutex
// Core caches
logger Logger
tokenCache *Cache
metadataCache *Cache
jwkCache *Cache
sessionCache *Cache
generalCache *Cache
// Typed wrappers
typedToken *TokenCache
typedMetadata *MetadataCache
typedJWK *JWKCache
typedSession *SessionCache
logger Logger
mu sync.RWMutex
}
var (
+313
View File
@@ -0,0 +1,313 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
)
// Common errors
var (
// ErrCircuitOpen is returned when the circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTooManyRequests is returned when too many requests are made in half-open state
ErrTooManyRequests = errors.New("too many requests in half-open state")
)
// State represents the state of the circuit breaker
type State int32
const (
// StateClosed allows all operations to pass through
StateClosed State = iota
// StateOpen blocks all operations
StateOpen
// StateHalfOpen allows a limited number of operations to test recovery
StateHalfOpen
)
// String returns the string representation of the state
func (s State) String() string {
switch s {
case StateClosed:
return "closed"
case StateOpen:
return "open"
case StateHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreakerConfig holds configuration for the circuit breaker
type CircuitBreakerConfig struct {
OnStateChange func(from, to State)
MaxFailures int
FailureThreshold float64
Timeout time.Duration
HalfOpenMaxRequests int
ResetTimeout time.Duration
}
// DefaultCircuitBreakerConfig returns default configuration
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
return &CircuitBreakerConfig{
MaxFailures: 5,
FailureThreshold: 0.6,
Timeout: 30 * time.Second,
HalfOpenMaxRequests: 3,
ResetTimeout: 60 * time.Second,
}
}
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
nextRetryTime time.Time
lastStateChange time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
config *CircuitBreakerConfig
totalFailures atomic.Int64
totalRequests atomic.Int64
stateTransitions atomic.Int64
rejectedRequests atomic.Int64
stateMu sync.RWMutex
timeMu sync.RWMutex
halfOpenRequests atomic.Int32
consecutiveFailures atomic.Int32
state atomic.Int32
}
// NewCircuitBreaker creates a new circuit breaker
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
if config == nil {
config = DefaultCircuitBreakerConfig()
}
return &CircuitBreaker{
config: config,
lastStateChange: time.Now(),
}
}
// Execute runs a function through the circuit breaker
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
if !cb.AllowRequest() {
cb.rejectedRequests.Add(1)
return ErrCircuitOpen
}
cb.totalRequests.Add(1)
err := fn()
if err != nil {
cb.RecordFailure()
} else {
cb.RecordSuccess()
}
return err
}
// AllowRequest checks if a request is allowed to proceed
func (cb *CircuitBreaker) AllowRequest() bool {
state := cb.GetState()
switch state {
case StateClosed:
return true
case StateOpen:
// Check if timeout has passed and we should try half-open
cb.timeMu.RLock()
shouldRetry := time.Now().After(cb.nextRetryTime)
cb.timeMu.RUnlock()
if shouldRetry {
cb.setState(StateHalfOpen)
return true
}
return false
case StateHalfOpen:
// Allow limited requests in half-open state
current := cb.halfOpenRequests.Add(1)
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
return current <= int32(cb.config.HalfOpenMaxRequests)
default:
return false
}
}
// RecordSuccess records a successful operation
func (cb *CircuitBreaker) RecordSuccess() {
cb.timeMu.Lock()
cb.lastSuccessTime = time.Now()
cb.timeMu.Unlock()
state := cb.GetState()
switch state {
case StateClosed:
// Reset consecutive failures
cb.consecutiveFailures.Store(0)
case StateHalfOpen:
// If we've had enough successful requests, close the circuit
successfulRequests := cb.halfOpenRequests.Load()
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
cb.halfOpenRequests.Store(0)
}
}
}
// RecordFailure records a failed operation
func (cb *CircuitBreaker) RecordFailure() {
cb.totalFailures.Add(1)
failures := cb.consecutiveFailures.Add(1)
cb.timeMu.Lock()
cb.lastFailureTime = time.Now()
cb.timeMu.Unlock()
state := cb.GetState()
switch state {
case StateClosed:
// Check if we should open the circuit
// #nosec G115 -- MaxFailures is a small config value that fits in int32
if failures >= int32(cb.config.MaxFailures) {
cb.openCircuit()
} else if cb.config.FailureThreshold > 0 {
// Check failure rate
total := cb.totalRequests.Load()
failureCount := cb.totalFailures.Load()
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
cb.openCircuit()
}
}
case StateHalfOpen:
// Any failure in half-open state reopens the circuit
cb.openCircuit()
}
}
// openCircuit transitions to open state
func (cb *CircuitBreaker) openCircuit() {
cb.setState(StateOpen)
cb.halfOpenRequests.Store(0)
cb.timeMu.Lock()
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
cb.timeMu.Unlock()
}
// GetState returns the current state
func (cb *CircuitBreaker) GetState() State {
return State(cb.state.Load())
}
// setState changes the circuit breaker state
func (cb *CircuitBreaker) setState(newState State) {
oldState := State(cb.state.Swap(int32(newState)))
if oldState != newState {
cb.stateTransitions.Add(1)
cb.stateMu.Lock()
cb.lastStateChange = time.Now()
cb.stateMu.Unlock()
if cb.config.OnStateChange != nil {
cb.config.OnStateChange(oldState, newState)
}
}
}
// Reset resets the circuit breaker to closed state
func (cb *CircuitBreaker) Reset() {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
cb.totalRequests.Store(0)
cb.totalFailures.Store(0)
cb.halfOpenRequests.Store(0)
cb.rejectedRequests.Store(0)
cb.stateTransitions.Store(0)
now := time.Now()
cb.timeMu.Lock()
cb.lastFailureTime = now
cb.lastSuccessTime = now
cb.nextRetryTime = now
cb.timeMu.Unlock()
cb.stateMu.Lock()
cb.lastStateChange = now
cb.stateMu.Unlock()
}
// Stats returns circuit breaker statistics
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
cb.timeMu.RLock()
lastFailure := cb.lastFailureTime
lastSuccess := cb.lastSuccessTime
nextRetry := cb.nextRetryTime
cb.timeMu.RUnlock()
cb.stateMu.RLock()
lastChange := cb.lastStateChange
cb.stateMu.RUnlock()
totalReq := cb.totalRequests.Load()
totalFail := cb.totalFailures.Load()
successRate := float64(0)
if totalReq > 0 {
successRate = float64(totalReq-totalFail) / float64(totalReq)
}
return CircuitBreakerStats{
State: cb.GetState(),
ConsecutiveFailures: cb.consecutiveFailures.Load(),
TotalRequests: totalReq,
TotalFailures: totalFail,
SuccessRate: successRate,
RejectedRequests: cb.rejectedRequests.Load(),
StateTransitions: cb.stateTransitions.Load(),
LastFailureTime: lastFailure,
LastSuccessTime: lastSuccess,
LastStateChange: lastChange,
NextRetryTime: nextRetry,
}
}
// CircuitBreakerStats holds statistics for the circuit breaker
type CircuitBreakerStats struct {
LastFailureTime time.Time
LastSuccessTime time.Time
LastStateChange time.Time
NextRetryTime time.Time
TotalRequests int64
TotalFailures int64
SuccessRate float64
RejectedRequests int64
StateTransitions int64
State State
ConsecutiveFailures int32
}
// IsHealthy returns true if the circuit breaker is in a healthy state
func (cb *CircuitBreaker) IsHealthy() bool {
return cb.GetState() != StateOpen
}
+141
View File
@@ -0,0 +1,141 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
)
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
type CircuitBreakerBackend struct {
backend backends.CacheBackend
cb *CircuitBreaker
}
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
if config == nil {
config = DefaultCircuitBreakerConfig()
}
return &CircuitBreakerBackend{
backend: b,
cb: NewCircuitBreaker(config),
}
}
// Set stores a value with circuit breaker protection
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Set(ctx, key, value, ttl)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// Get retrieves a value with circuit breaker protection
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if !c.cb.AllowRequest() {
return nil, 0, false, backends.ErrCircuitOpen
}
value, ttl, exists, err := c.backend.Get(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return value, ttl, exists, err
}
// Delete removes a key with circuit breaker protection
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
if !c.cb.AllowRequest() {
return false, backends.ErrCircuitOpen
}
deleted, err := c.backend.Delete(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return deleted, err
}
// Exists checks if a key exists with circuit breaker protection
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
if !c.cb.AllowRequest() {
return false, backends.ErrCircuitOpen
}
exists, err := c.backend.Exists(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return exists, err
}
// Clear removes all keys with circuit breaker protection
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Clear(ctx)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// GetStats returns statistics including circuit breaker state
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
stats := c.backend.GetStats()
if stats == nil {
stats = make(map[string]interface{})
}
cbStats := c.cb.Stats()
stats["circuit_breaker"] = map[string]interface{}{
"state": cbStats.State.String(),
"consecutive_failures": cbStats.ConsecutiveFailures,
"total_requests": cbStats.TotalRequests,
"total_failures": cbStats.TotalFailures,
"success_rate": cbStats.SuccessRate,
}
return stats
}
// Ping checks backend health with circuit breaker protection
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Ping(ctx)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// Close shuts down the backend
func (c *CircuitBreakerBackend) Close() error {
return c.backend.Close()
}
@@ -0,0 +1,561 @@
//go:build !yaegi
package resilience
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockBackend is a simple mock implementation for testing
type mockBackend struct {
data map[string]mockEntry
mu sync.RWMutex
failSet bool
failGet bool
failDelete bool
failExists bool
failClear bool
failPing bool
callCount int
}
type mockEntry struct {
expiresAt time.Time
value []byte
}
func newMockBackend() *mockBackend {
return &mockBackend{
data: make(map[string]mockEntry),
}
}
func (m *mockBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.failSet {
return errors.New("mock set error")
}
expiresAt := time.Now().Add(ttl)
if ttl == 0 {
expiresAt = time.Now().Add(24 * time.Hour)
}
m.data[key] = mockEntry{
value: value,
expiresAt: expiresAt,
}
return nil
}
func (m *mockBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
m.mu.RLock()
defer m.mu.RUnlock()
m.callCount++
if m.failGet {
return nil, 0, false, errors.New("mock get error")
}
entry, exists := m.data[key]
if !exists {
return nil, 0, false, nil
}
if time.Now().After(entry.expiresAt) {
return nil, 0, false, nil
}
ttl := time.Until(entry.expiresAt)
return entry.value, ttl, true, nil
}
func (m *mockBackend) Delete(ctx context.Context, key string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.failDelete {
return false, errors.New("mock delete error")
}
_, existed := m.data[key]
delete(m.data, key)
return existed, nil
}
func (m *mockBackend) Exists(ctx context.Context, key string) (bool, error) {
m.mu.RLock()
defer m.mu.RUnlock()
m.callCount++
if m.failExists {
return false, errors.New("mock exists error")
}
entry, exists := m.data[key]
if !exists {
return false, nil
}
if time.Now().After(entry.expiresAt) {
return false, nil
}
return true, nil
}
func (m *mockBackend) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.failClear {
return errors.New("mock clear error")
}
m.data = make(map[string]mockEntry)
return nil
}
func (m *mockBackend) GetStats() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
return map[string]interface{}{
"hits": int64(0),
"misses": int64(0),
"call_count": m.callCount,
}
}
func (m *mockBackend) Close() error {
return nil
}
func (m *mockBackend) Ping(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.failPing {
return errors.New("mock ping error")
}
return nil
}
// Constructor Tests
func TestNewCircuitBreakerBackend_WithDefaultConfig(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
require.NotNil(t, cb)
// Verify it implements the interface (compile-time check)
var _ backends.CacheBackend = cb
}
func TestNewCircuitBreakerBackend_WithCustomConfig(t *testing.T) {
mockBE := newMockBackend()
config := &CircuitBreakerConfig{
MaxFailures: 3,
FailureThreshold: 0.5,
Timeout: 5 * time.Second,
HalfOpenMaxRequests: 2,
ResetTimeout: 10 * time.Second,
}
cb := NewCircuitBreakerBackend(mockBE, config)
require.NotNil(t, cb)
}
// Set Operation Tests
func TestCircuitBreakerBackend_Set_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
assert.NoError(t, err)
assert.Equal(t, 1, mockBE.callCount)
// Verify value was stored
value, _, exists, _ := mockBE.Get(ctx, "key1")
assert.True(t, exists)
assert.Equal(t, []byte("value1"), value)
}
func TestCircuitBreakerBackend_Set_Failure(t *testing.T) {
mockBE := newMockBackend()
mockBE.failSet = true
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
assert.Error(t, err)
}
func TestCircuitBreakerBackend_Set_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failSet = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures to open circuit
for i := 0; i < 5; i++ {
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
}
// Circuit should be open now
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// Get Operation Tests
func TestCircuitBreakerBackend_Get_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
// First set a value
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
// Now get it through circuit breaker
value, _, exists, err := cb.Get(ctx, "key1")
assert.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("value1"), value)
}
func TestCircuitBreakerBackend_Get_Failure(t *testing.T) {
mockBE := newMockBackend()
mockBE.failGet = true
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
_, _, _, err := cb.Get(ctx, "key1")
assert.Error(t, err)
}
func TestCircuitBreakerBackend_Get_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failGet = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures
for i := 0; i < 5; i++ {
cb.Get(ctx, "key")
}
// Circuit should be open
_, _, _, err := cb.Get(ctx, "key2")
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// Delete Operation Tests
func TestCircuitBreakerBackend_Delete_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
// Set a value first
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
// Delete through circuit breaker
deleted, err := cb.Delete(ctx, "key1")
assert.NoError(t, err)
assert.True(t, deleted)
// Verify it's deleted
exists, _ := mockBE.Exists(ctx, "key1")
assert.False(t, exists)
}
func TestCircuitBreakerBackend_Delete_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failDelete = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures
for i := 0; i < 5; i++ {
cb.Delete(ctx, "key")
}
// Circuit should be open
_, err := cb.Delete(ctx, "key2")
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// Exists Operation Tests
func TestCircuitBreakerBackend_Exists_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
// Set a value first
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
// Check existence through circuit breaker
exists, err := cb.Exists(ctx, "key1")
assert.NoError(t, err)
assert.True(t, exists)
}
func TestCircuitBreakerBackend_Exists_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failExists = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures
for i := 0; i < 5; i++ {
cb.Exists(ctx, "key")
}
// Circuit should be open
_, err := cb.Exists(ctx, "key2")
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// Clear Operation Tests
func TestCircuitBreakerBackend_Clear_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
// Set some values
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
mockBE.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
// Clear through circuit breaker
err := cb.Clear(ctx)
assert.NoError(t, err)
// Verify cleared
exists1, _ := mockBE.Exists(ctx, "key1")
exists2, _ := mockBE.Exists(ctx, "key2")
assert.False(t, exists1)
assert.False(t, exists2)
}
func TestCircuitBreakerBackend_Clear_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failClear = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures
for i := 0; i < 5; i++ {
cb.Clear(ctx)
}
// Circuit should be open
err := cb.Clear(ctx)
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// GetStats Tests
func TestCircuitBreakerBackend_GetStats(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
// Perform some operations
cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
cb.Get(ctx, "key1")
stats := cb.GetStats()
require.NotNil(t, stats)
// Should have circuit breaker stats
assert.Contains(t, stats, "circuit_breaker")
cbStats, ok := stats["circuit_breaker"].(map[string]interface{})
require.True(t, ok)
// Verify circuit breaker stats fields
assert.Contains(t, cbStats, "state")
assert.Contains(t, cbStats, "consecutive_failures")
assert.Contains(t, cbStats, "total_requests")
assert.Contains(t, cbStats, "total_failures")
assert.Contains(t, cbStats, "success_rate")
}
func TestCircuitBreakerBackend_GetStats_NilBackendStats(t *testing.T) {
// Create a mock backend that returns nil stats
mockBE := &mockBackendNilStats{}
cb := NewCircuitBreakerBackend(mockBE, nil)
stats := cb.GetStats()
require.NotNil(t, stats)
assert.Contains(t, stats, "circuit_breaker")
}
// mockBackendNilStats returns nil from GetStats
type mockBackendNilStats struct {
mockBackend
}
func (m *mockBackendNilStats) GetStats() map[string]interface{} {
return nil
}
// Ping Tests
func TestCircuitBreakerBackend_Ping_Success(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
ctx := context.Background()
err := cb.Ping(ctx)
assert.NoError(t, err)
}
func TestCircuitBreakerBackend_Ping_CircuitOpen(t *testing.T) {
mockBE := newMockBackend()
mockBE.failPing = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures
for i := 0; i < 5; i++ {
cb.Ping(ctx)
}
// Circuit should be open
err := cb.Ping(ctx)
assert.Error(t, err)
assert.Equal(t, backends.ErrCircuitOpen, err)
}
// Close Tests
func TestCircuitBreakerBackend_Close(t *testing.T) {
mockBE := newMockBackend()
cb := NewCircuitBreakerBackend(mockBE, nil)
err := cb.Close()
assert.NoError(t, err)
}
// Circuit Recovery Test
func TestCircuitBreakerBackend_CircuitRecovery(t *testing.T) {
mockBE := newMockBackend()
mockBE.failSet = true
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 200 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreakerBackend(mockBE, config)
ctx := context.Background()
// Trigger failures to open circuit
for i := 0; i < 5; i++ {
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
}
// Verify circuit is open
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
assert.Equal(t, backends.ErrCircuitOpen, err)
// Wait for timeout
time.Sleep(250 * time.Millisecond)
// Fix the backend
mockBE.mu.Lock()
mockBE.failSet = false
mockBE.mu.Unlock()
// Circuit should be in half-open state, allow a test request
err = cb.Set(ctx, "key3", []byte("value3"), 1*time.Minute)
// After success threshold is met, circuit should close
if err == nil {
// Circuit recovered
err2 := cb.Set(ctx, "key4", []byte("value4"), 1*time.Minute)
assert.NoError(t, err2, "Circuit should be closed after recovery")
}
}
+553
View File
@@ -0,0 +1,553 @@
package resilience
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestCircuitBreaker_StateTransitions tests state machine transitions
func TestCircuitBreaker_StateTransitions(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
t.Run("Initial state is closed", func(t *testing.T) {
assert.Equal(t, StateClosed, cb.GetState())
})
t.Run("Closed to Open after max failures", func(t *testing.T) {
cb.Reset()
// Simulate failures
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
})
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
// Open the circuit
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Wait for timeout
time.Sleep(150 * time.Millisecond)
// Should allow request and transition to half-open
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
assert.Equal(t, StateHalfOpen, cb.GetState())
})
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
// Open circuit then wait for half-open
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
time.Sleep(150 * time.Millisecond)
// First request transitions to half-open and succeeds
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
// Should be in half-open after first request
state := cb.GetState()
assert.True(t, state == StateHalfOpen || state == StateClosed,
"After first successful request, should be half-open or potentially closed")
if state == StateHalfOpen {
// Need more successful requests to close
// The exact number depends on implementation but should be within HalfOpenMaxRequests
for i := 0; i < config.HalfOpenMaxRequests; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
// After multiple successful requests, should eventually close
finalState := cb.GetState()
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
"After successful requests, circuit should transition towards closed")
}
})
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
// Open circuit then wait for half-open
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
time.Sleep(150 * time.Millisecond)
// First call transitions to half-open, second failure reopens
cb.Execute(ctx, func() error {
return errors.New("test error")
})
assert.Equal(t, StateOpen, cb.GetState())
})
}
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Trigger failures to open circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Requests should be blocked
err := cb.Execute(ctx, func() error {
t.Fatal("Should not execute function when circuit is open")
return nil
})
assert.Error(t, err)
assert.Equal(t, ErrCircuitOpen, err)
}
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open circuit then wait for half-open
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
time.Sleep(150 * time.Millisecond)
// After timeout, circuit should allow transition to half-open
// Execute HalfOpenMaxRequests successful requests
successCount := 0
for i := 0; i < config.HalfOpenMaxRequests; i++ {
err := cb.Execute(ctx, func() error {
successCount++
return nil
})
// Should allow up to HalfOpenMaxRequests
assert.NoError(t, err)
}
// Verify we executed the expected number
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
// After successful requests, circuit behavior depends on implementation
// It could close (allowing more requests) or stay half-open (blocking)
// The important thing is that we allowed exactly HalfOpenMaxRequests
}
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Have some failures (but less than max)
cb.Execute(ctx, func() error {
return errors.New("error")
})
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateClosed, cb.GetState())
stats := cb.Stats()
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
// One success should reset failures
cb.Execute(ctx, func() error {
return nil
})
assert.Equal(t, StateClosed, cb.GetState())
stats = cb.Stats()
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
}
// TestCircuitBreaker_ConcurrentAccess tests thread safety
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 10,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 5,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Mix of successes and failures
cb.Execute(ctx, func() error {
if (id+j)%3 == 0 {
return errors.New("test error")
}
return nil
})
// Random state checks
_ = cb.GetState()
_ = cb.Stats()
}
}(i)
}
wg.Wait()
// Should complete without panics
stats := cb.Stats()
assert.NotNil(t, stats)
}
// TestCircuitBreaker_Stats tests statistics tracking
func TestCircuitBreaker_Stats(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 5,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Execute some requests
cb.Execute(ctx, func() error { return nil }) // Success
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
stats := cb.Stats()
assert.Equal(t, StateClosed, stats.State)
assert.Equal(t, int64(3), stats.TotalRequests)
assert.Equal(t, int64(2), stats.TotalFailures)
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
}
// TestCircuitBreaker_Reset tests circuit reset
func TestCircuitBreaker_Reset(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open the circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Reset
cb.Reset()
assert.Equal(t, StateClosed, cb.GetState())
stats := cb.Stats()
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
assert.Equal(t, int64(0), stats.TotalRequests)
assert.Equal(t, int64(0), stats.TotalFailures)
}
// TestCircuitBreaker_StateChangeCallback tests state change notifications
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
t.Parallel()
var transitions []string
var mu sync.Mutex
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 50 * time.Millisecond,
HalfOpenMaxRequests: 1,
OnStateChange: func(from, to State) {
mu.Lock()
defer mu.Unlock()
transitions = append(transitions, from.String()+"->"+to.String())
},
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Trigger state transitions
// Closed -> Open
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
// Should be open now
assert.Equal(t, StateOpen, cb.GetState())
// Wait for timeout to allow half-open transition
time.Sleep(100 * time.Millisecond)
// Open -> HalfOpen on first request after timeout
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
// Execute more successful requests to trigger HalfOpen -> Closed
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
mu.Lock()
defer mu.Unlock()
assert.Contains(t, transitions, "closed->open")
assert.Contains(t, transitions, "open->half-open")
}
// TestCircuitBreaker_IsHealthy tests health check
func TestCircuitBreaker_IsHealthy(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Initially healthy
assert.True(t, cb.IsHealthy())
// Open circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
// Wait for timeout and allow successful request
time.Sleep(150 * time.Millisecond)
cb.Execute(ctx, func() error {
return nil
})
// Should be healthy after recovery
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
}
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
func TestCircuitBreaker_RapidFailures(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 5,
Timeout: 200 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Rapid failures
for i := 0; i < 10; i++ {
cb.Execute(ctx, func() error {
return errors.New("rapid error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
stats := cb.Stats()
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
}
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
t.Parallel()
timeout := 100 * time.Millisecond
config := &CircuitBreakerConfig{
MaxFailures: 1,
Timeout: timeout,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open circuit
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateOpen, cb.GetState())
// Wait just before timeout
time.Sleep(timeout - 20*time.Millisecond)
assert.False(t, cb.IsHealthy())
// Wait until after timeout
time.Sleep(40 * time.Millisecond)
// After timeout, AllowRequest should return true for transition to half-open
assert.True(t, cb.AllowRequest())
}
// TestCircuitBreaker_DefaultConfig tests default configuration
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
t.Parallel()
cb := NewCircuitBreaker(nil) // Should use defaults
assert.NotNil(t, cb)
assert.Equal(t, StateClosed, cb.GetState())
// Verify defaults by triggering circuit breaker behavior
ctx := context.Background()
// Test that it takes 5 failures to open (default MaxFailures)
for i := 0; i < 4; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
// 5th failure should open it
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
}
// TestCircuitBreaker_StateString tests state string representation
func TestCircuitBreaker_StateString(t *testing.T) {
t.Parallel()
assert.Equal(t, "closed", StateClosed.String())
assert.Equal(t, "open", StateOpen.String())
assert.Equal(t, "half-open", StateHalfOpen.String())
assert.Equal(t, "unknown", State(999).String())
}
// Benchmark circuit breaker performance
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
config := &CircuitBreakerConfig{
MaxFailures: 100,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 10,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
}
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
config := &CircuitBreakerConfig{
MaxFailures: 1000,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 10,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.Execute(ctx, func() error {
if i%10 == 0 {
return errors.New("error")
}
return nil
})
}
}
+356
View File
@@ -0,0 +1,356 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"sync"
"sync/atomic"
"time"
)
// HealthStatus represents the health status of a backend
type HealthStatus int32
const (
// HealthUnknown indicates unknown health status
HealthUnknown HealthStatus = iota
// HealthHealthy indicates the backend is healthy
HealthHealthy
// HealthDegraded indicates the backend is degraded but operational
HealthDegraded
// HealthUnhealthy indicates the backend is unhealthy
HealthUnhealthy
)
// String returns the string representation of the health status
func (h HealthStatus) String() string {
switch h {
case HealthHealthy:
return "healthy"
case HealthDegraded:
return "degraded"
case HealthUnhealthy:
return "unhealthy"
default:
return "unknown"
}
}
// HealthCheckConfig holds configuration for the health checker
type HealthCheckConfig struct {
OnStatusChange func(from, to HealthStatus)
CheckFunc func(ctx context.Context) error
CheckInterval time.Duration
Timeout time.Duration
HealthyThreshold int
UnhealthyThreshold int
DegradedThreshold time.Duration
}
// DefaultHealthCheckConfig returns default configuration
func DefaultHealthCheckConfig() *HealthCheckConfig {
return &HealthCheckConfig{
CheckInterval: 30 * time.Second,
Timeout: 5 * time.Second,
HealthyThreshold: 3,
UnhealthyThreshold: 3,
DegradedThreshold: 100 * time.Millisecond,
}
}
// HealthChecker monitors the health of a backend
type HealthChecker struct {
lastCheckTime time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
config *HealthCheckConfig
stopChan chan struct{}
ticker *time.Ticker
wg sync.WaitGroup
statusChanges atomic.Int64
totalChecks atomic.Int64
totalSuccesses atomic.Int64
totalFailures atomic.Int64
averageLatency atomic.Int64
timeMu sync.RWMutex
consecutiveFailures atomic.Int32
consecutiveSuccesses atomic.Int32
stopped atomic.Bool
status atomic.Int32
}
// NewHealthChecker creates a new health checker
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
if config == nil {
config = DefaultHealthCheckConfig()
}
hc := &HealthChecker{
config: config,
stopChan: make(chan struct{}),
}
hc.status.Store(int32(HealthUnknown))
return hc
}
// Start begins health checking
func (hc *HealthChecker) Start() {
if hc.stopped.Load() {
return
}
hc.ticker = time.NewTicker(hc.config.CheckInterval)
hc.wg.Add(1)
go hc.checkLoop()
}
// Stop stops health checking
func (hc *HealthChecker) Stop() {
if hc.stopped.Swap(true) {
return // Already stopped
}
close(hc.stopChan)
if hc.ticker != nil {
hc.ticker.Stop()
}
hc.wg.Wait()
}
// checkLoop runs periodic health checks
func (hc *HealthChecker) checkLoop() {
defer hc.wg.Done()
// Initial check - log error but continue
if err := hc.Check(context.Background()); err != nil {
// Error is already tracked in Check() method, no need to log again
_ = err
}
for {
select {
case <-hc.stopChan:
return
case <-hc.ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
if err := hc.Check(ctx); err != nil {
// Error is already tracked in Check() method, no need to log again
_ = err
}
cancel()
}
}
}
// Check performs a health check
func (hc *HealthChecker) Check(ctx context.Context) error {
if hc.config.CheckFunc == nil {
return nil
}
hc.totalChecks.Add(1)
start := time.Now()
// Create timeout context if not already set
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
defer cancel()
}
// Perform health check
err := hc.config.CheckFunc(ctx)
latency := time.Since(start)
hc.timeMu.Lock()
hc.lastCheckTime = time.Now()
hc.timeMu.Unlock()
// Update average latency
hc.updateAverageLatency(latency)
if err != nil {
hc.recordFailure()
} else {
hc.recordSuccess(latency)
}
return err
}
// recordSuccess records a successful health check
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
hc.totalSuccesses.Add(1)
successes := hc.consecutiveSuccesses.Add(1)
hc.consecutiveFailures.Store(0)
hc.timeMu.Lock()
hc.lastSuccessTime = time.Now()
hc.timeMu.Unlock()
currentStatus := hc.GetStatus()
newStatus := currentStatus
// Check if we should become healthy
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
if successes >= int32(hc.config.HealthyThreshold) {
if latency > hc.config.DegradedThreshold {
newStatus = HealthDegraded
} else {
newStatus = HealthHealthy
}
}
if newStatus != currentStatus {
hc.setStatus(newStatus)
}
}
// recordFailure records a failed health check
func (hc *HealthChecker) recordFailure() {
hc.totalFailures.Add(1)
failures := hc.consecutiveFailures.Add(1)
hc.consecutiveSuccesses.Store(0)
hc.timeMu.Lock()
hc.lastFailureTime = time.Now()
hc.timeMu.Unlock()
// Check if we should become unhealthy
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
if failures >= int32(hc.config.UnhealthyThreshold) {
hc.setStatus(HealthUnhealthy)
}
}
// updateAverageLatency updates the rolling average latency
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
// Simple exponential moving average
currentAvg := time.Duration(hc.averageLatency.Load())
if currentAvg == 0 {
hc.averageLatency.Store(int64(latency))
} else {
// Weight: 0.2 for new value, 0.8 for old average
newAvg := (currentAvg*4 + latency) / 5
hc.averageLatency.Store(int64(newAvg))
}
}
// GetStatus returns the current health status
func (hc *HealthChecker) GetStatus() HealthStatus {
return HealthStatus(hc.status.Load())
}
// setStatus changes the health status
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
if oldStatus != newStatus {
hc.statusChanges.Add(1)
if hc.config.OnStatusChange != nil {
hc.config.OnStatusChange(oldStatus, newStatus)
}
}
}
// IsHealthy returns true if the backend is healthy or degraded
func (hc *HealthChecker) IsHealthy() bool {
status := hc.GetStatus()
return status == HealthHealthy || status == HealthDegraded
}
// LastCheckTime returns the time of the last health check
func (hc *HealthChecker) LastCheckTime() time.Time {
hc.timeMu.RLock()
defer hc.timeMu.RUnlock()
return hc.lastCheckTime
}
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
func (hc *HealthChecker) HealthScore() float64 {
status := hc.GetStatus()
switch status {
case HealthHealthy:
return 1.0
case HealthDegraded:
return 0.7
case HealthUnhealthy:
return 0.0
default:
return 0.5
}
}
// Stats returns health checker statistics
func (hc *HealthChecker) Stats() HealthCheckerStats {
hc.timeMu.RLock()
lastCheck := hc.lastCheckTime
lastSuccess := hc.lastSuccessTime
lastFailure := hc.lastFailureTime
hc.timeMu.RUnlock()
totalChecks := hc.totalChecks.Load()
totalSuccesses := hc.totalSuccesses.Load()
totalFailures := hc.totalFailures.Load()
successRate := float64(0)
if totalChecks > 0 {
successRate = float64(totalSuccesses) / float64(totalChecks)
}
return HealthCheckerStats{
Status: hc.GetStatus(),
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
ConsecutiveFailures: hc.consecutiveFailures.Load(),
TotalChecks: totalChecks,
TotalSuccesses: totalSuccesses,
TotalFailures: totalFailures,
SuccessRate: successRate,
AverageLatency: time.Duration(hc.averageLatency.Load()),
StatusChanges: hc.statusChanges.Load(),
LastCheckTime: lastCheck,
LastSuccessTime: lastSuccess,
LastFailureTime: lastFailure,
HealthScore: hc.HealthScore(),
}
}
// HealthCheckerStats holds statistics for the health checker
type HealthCheckerStats struct {
LastCheckTime time.Time
LastFailureTime time.Time
LastSuccessTime time.Time
TotalChecks int64
TotalSuccesses int64
TotalFailures int64
SuccessRate float64
AverageLatency time.Duration
StatusChanges int64
HealthScore float64
Status HealthStatus
ConsecutiveFailures int32
ConsecutiveSuccesses int32
}
// Reset resets the health checker statistics
func (hc *HealthChecker) Reset() {
hc.status.Store(int32(HealthUnknown))
hc.consecutiveSuccesses.Store(0)
hc.consecutiveFailures.Store(0)
hc.totalChecks.Store(0)
hc.totalSuccesses.Store(0)
hc.totalFailures.Store(0)
hc.statusChanges.Store(0)
hc.averageLatency.Store(0)
now := time.Now()
hc.timeMu.Lock()
hc.lastCheckTime = now
hc.lastSuccessTime = now
hc.lastFailureTime = now
hc.timeMu.Unlock()
}
+212
View File
@@ -0,0 +1,212 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
)
// HealthCheckBackend wraps a cache backend with health checking
type HealthCheckBackend struct {
lastCheck time.Time
backend backends.CacheBackend
ctx context.Context
config *HealthCheckConfig
cancel context.CancelFunc
wg sync.WaitGroup
checkMutex sync.RWMutex
status atomic.Int32
consecutiveFails atomic.Int32
consecutiveOK atomic.Int32
}
// NewHealthCheckBackend creates a new health check wrapped backend
func NewHealthCheckBackend(b backends.CacheBackend, config *HealthCheckConfig) backends.CacheBackend {
if config == nil {
config = DefaultHealthCheckConfig()
}
ctx, cancel := context.WithCancel(context.Background())
hc := &HealthCheckBackend{
backend: b,
config: config,
ctx: ctx,
cancel: cancel,
}
// Set initial status to healthy (optimistic)
hc.status.Store(int32(HealthHealthy))
// Start health check routine
hc.wg.Add(1)
go hc.healthCheckLoop()
return hc
}
// Set stores a value and tracks health
func (h *HealthCheckBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Allow operations even if unhealthy (may recover)
err := h.backend.Set(ctx, key, value, ttl)
h.recordResult(err == nil)
return err
}
// Get retrieves a value and tracks health
func (h *HealthCheckBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
value, ttl, exists, err := h.backend.Get(ctx, key)
h.recordResult(err == nil)
return value, ttl, exists, err
}
// Delete removes a key and tracks health
func (h *HealthCheckBackend) Delete(ctx context.Context, key string) (bool, error) {
deleted, err := h.backend.Delete(ctx, key)
h.recordResult(err == nil)
return deleted, err
}
// Exists checks if a key exists and tracks health
func (h *HealthCheckBackend) Exists(ctx context.Context, key string) (bool, error) {
exists, err := h.backend.Exists(ctx, key)
h.recordResult(err == nil)
return exists, err
}
// Clear removes all keys and tracks health
func (h *HealthCheckBackend) Clear(ctx context.Context) error {
err := h.backend.Clear(ctx)
h.recordResult(err == nil)
return err
}
// GetStats returns statistics including health status
func (h *HealthCheckBackend) GetStats() map[string]interface{} {
stats := h.backend.GetStats()
if stats == nil {
stats = make(map[string]interface{})
}
h.checkMutex.RLock()
lastCheck := h.lastCheck
h.checkMutex.RUnlock()
status := HealthStatus(h.status.Load())
stats["health"] = map[string]interface{}{
"status": status.String(),
"consecutive_fails": h.consecutiveFails.Load(),
"consecutive_ok": h.consecutiveOK.Load(),
"last_check": lastCheck.Format(time.RFC3339),
"time_since_check": time.Since(lastCheck).Seconds(),
"check_interval_sec": h.config.CheckInterval.Seconds(),
}
return stats
}
// Ping checks backend health
func (h *HealthCheckBackend) Ping(ctx context.Context) error {
err := h.backend.Ping(ctx)
h.recordResult(err == nil)
return err
}
// Close shuts down the health checker and backend
func (h *HealthCheckBackend) Close() error {
// Stop health check routine
h.cancel()
// Wait for routine to finish
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Finished normally
case <-time.After(2 * time.Second):
// Timeout
}
return h.backend.Close()
}
// IsHealthy returns true if the backend is healthy
func (h *HealthCheckBackend) IsHealthy() bool {
status := HealthStatus(h.status.Load())
return status == HealthHealthy || status == HealthDegraded
}
// recordResult records the result of an operation for health tracking
func (h *HealthCheckBackend) recordResult(success bool) {
// #nosec G115 -- threshold config values are small integers that fit in int32
if success {
fails := h.consecutiveFails.Swap(0)
oks := h.consecutiveOK.Add(1)
// Check if we should transition to healthy
if fails > 0 && oks >= int32(h.config.HealthyThreshold) {
oldStatus := HealthStatus(h.status.Swap(int32(HealthHealthy)))
if oldStatus != HealthHealthy && h.config.OnStatusChange != nil {
h.config.OnStatusChange(oldStatus, HealthHealthy)
}
}
} else {
oks := h.consecutiveOK.Swap(0)
fails := h.consecutiveFails.Add(1)
// Check if we should transition to unhealthy
if oks > 0 && fails >= int32(h.config.UnhealthyThreshold) {
oldStatus := HealthStatus(h.status.Swap(int32(HealthUnhealthy)))
if oldStatus != HealthUnhealthy && h.config.OnStatusChange != nil {
h.config.OnStatusChange(oldStatus, HealthUnhealthy)
}
} else if fails >= int32(h.config.UnhealthyThreshold)*2 {
// Severely degraded
h.status.Store(int32(HealthUnhealthy))
} else if fails >= int32(h.config.UnhealthyThreshold) {
// Degraded but still trying
h.status.Store(int32(HealthDegraded))
}
}
}
// healthCheckLoop runs periodic health checks
func (h *HealthCheckBackend) healthCheckLoop() {
defer h.wg.Done()
ticker := time.NewTicker(h.config.CheckInterval)
defer ticker.Stop()
// Do initial check
h.performHealthCheck()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
h.performHealthCheck()
}
}
}
// performHealthCheck performs a single health check
func (h *HealthCheckBackend) performHealthCheck() {
h.checkMutex.Lock()
h.lastCheck = time.Now()
h.checkMutex.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), h.config.Timeout)
defer cancel()
err := h.backend.Ping(ctx)
h.recordResult(err == nil)
}
+447
View File
@@ -0,0 +1,447 @@
package resilience
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestHealthChecker_StatusTransitions tests health status transitions
func TestHealthChecker_StatusTransitions(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
var shouldFail atomic.Bool
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
if shouldFail.Load() {
return errors.New("health check failed")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
// Initially unknown
assert.Equal(t, HealthUnknown, hc.GetStatus())
// Trigger failures
shouldFail.Store(true)
time.Sleep(200 * time.Millisecond)
// Should be unhealthy after threshold failures
status := hc.GetStatus()
assert.True(t, status == HealthUnhealthy || status == HealthDegraded)
// Recover
shouldFail.Store(false)
time.Sleep(150 * time.Millisecond)
// Should recover towards healthy
finalStatus := hc.GetStatus()
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded || finalStatus == HealthUnknown)
}
// TestHealthChecker_InitialState tests initial health status
func TestHealthChecker_InitialState(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
assert.Equal(t, HealthUnknown, hc.GetStatus())
assert.False(t, hc.IsHealthy())
}
// TestHealthChecker_ForceCheck tests manual health check trigger
func TestHealthChecker_ForceCheck(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Second, // Long interval
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
initialCount := callCount.Load()
// Force check
hc.Check(context.Background())
// Should have been called
assert.Greater(t, callCount.Load(), initialCount)
}
// TestHealthChecker_StatusChangeCallback tests status change notifications
func TestHealthChecker_StatusChangeCallback(t *testing.T) {
t.Parallel()
var transitions []string
var mu sync.Mutex
var shouldFail atomic.Bool
checkFunc := func(ctx context.Context) error {
if shouldFail.Load() {
return errors.New("health check failed")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 2,
HealthyThreshold: 2,
CheckFunc: checkFunc,
OnStatusChange: func(from, to HealthStatus) {
mu.Lock()
defer mu.Unlock()
transitions = append(transitions, from.String()+"->"+to.String())
},
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
// Trigger failures
shouldFail.Store(true)
time.Sleep(100 * time.Millisecond)
// Recover
shouldFail.Store(false)
time.Sleep(100 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
// Should have status transitions
assert.NotEmpty(t, transitions)
}
// TestHealthChecker_Stats tests statistics tracking
func TestHealthChecker_Stats(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
if callCount.Load()%2 == 0 {
return errors.New("failure")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 20 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 5,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(150 * time.Millisecond)
stats := hc.Stats()
assert.Greater(t, stats.TotalChecks, int64(0))
assert.Greater(t, stats.TotalFailures, int64(0))
assert.Greater(t, stats.SuccessRate, 0.0)
assert.Less(t, stats.SuccessRate, 1.0)
}
// TestHealthChecker_Timeout tests check timeout handling
func TestHealthChecker_Timeout(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
// Simulate slow check
select {
case <-time.After(100 * time.Millisecond):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
config := &HealthCheckConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 10 * time.Millisecond, // Short timeout
UnhealthyThreshold: 2,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(150 * time.Millisecond)
// Should be unhealthy due to timeouts
status := hc.GetStatus()
assert.NotEqual(t, HealthHealthy, status)
}
// TestHealthChecker_ConcurrentAccess tests thread safety
func TestHealthChecker_ConcurrentAccess(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Millisecond,
Timeout: 5 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
var wg sync.WaitGroup
goroutines := 20
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 50; j++ {
_ = hc.GetStatus()
_ = hc.IsHealthy()
_ = hc.Stats()
hc.Check(context.Background())
}
}()
}
wg.Wait()
// Should complete without panics
}
// TestHealthChecker_StopAndStart tests lifecycle management
func TestHealthChecker_StopAndStart(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
return nil
}
config := &HealthCheckConfig{
CheckInterval: 20 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
// Start
hc.Start()
time.Sleep(100 * time.Millisecond)
count1 := callCount.Load()
assert.Greater(t, count1, int32(0))
// Stop
hc.Stop()
time.Sleep(100 * time.Millisecond)
count2 := callCount.Load()
// Should not have increased significantly after stop
assert.Less(t, count2-count1, int32(3))
}
// TestHealthChecker_DegradedState tests degraded status
func TestHealthChecker_DegradedState(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
count := callCount.Add(1)
// Fail once, then succeed
if count == 1 {
return errors.New("single failure")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3, // Need 3 failures for unhealthy
HealthyThreshold: 2, // Need 2 successes for healthy
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(100 * time.Millisecond)
// After initial checks, status should be set (might be healthy or degraded based on execution)
status := hc.GetStatus()
assert.True(t, status != HealthUnknown, "Status should not be unknown after checks")
}
// TestHealthChecker_DefaultConfig tests default configuration
func TestHealthChecker_DefaultConfig(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
assert.NotNil(t, hc)
assert.Equal(t, HealthUnknown, hc.GetStatus())
// Verify default config was applied (we can't access private fields, so just check it works)
assert.NotNil(t, hc)
}
// TestHealthChecker_StatusString tests status string representation
func TestHealthChecker_StatusString(t *testing.T) {
t.Parallel()
assert.Equal(t, "healthy", HealthHealthy.String())
assert.Equal(t, "unhealthy", HealthUnhealthy.String())
assert.Equal(t, "degraded", HealthDegraded.String())
assert.Equal(t, "unknown", HealthStatus(999).String())
}
// TestHealthChecker_RecoveryPattern tests typical failure and recovery
func TestHealthChecker_RecoveryPattern(t *testing.T) {
t.Parallel()
var checkNumber atomic.Int32
checkFunc := func(ctx context.Context) error {
n := checkNumber.Add(1)
// Fail checks 3-5, succeed others
if n >= 3 && n <= 5 {
return errors.New("temporary failure")
}
return nil
}
var statusLog []HealthStatus
var mu sync.Mutex
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
OnStatusChange: func(from, to HealthStatus) {
mu.Lock()
defer mu.Unlock()
statusLog = append(statusLog, to)
},
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(300 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
// Should see transitions through unhealthy and back to healthy
assert.NotEmpty(t, statusLog)
// Final status should be healthy or degraded (recovered)
finalStatus := hc.GetStatus()
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded, "Should have recovered")
}
// Benchmark health checker performance
func BenchmarkHealthChecker_ForceCheck(b *testing.B) {
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Minute,
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
b.ResetTimer()
for i := 0; i < b.N; i++ {
hc.Check(context.Background())
}
}
func BenchmarkHealthChecker_Status(b *testing.B) {
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = hc.GetStatus()
}
}
+19 -5
View File
@@ -1,9 +1,12 @@
package cache
import (
"bytes"
"encoding/json"
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/pool"
)
// TypedCache provides a type-safe wrapper around Cache for specific types
@@ -42,13 +45,24 @@ func (tc *TypedCache[T]) Get(key string) (T, bool) {
}
// If that fails, try JSON marshaling/unmarshaling for complex types
data, err := json.Marshal(value)
if err != nil {
// Use pooled buffer for encoding
pm := pool.Get()
buf := pm.GetBuffer(256)
defer pm.PutBuffer(buf)
encoder := pm.GetJSONEncoder(buf)
defer pm.PutJSONEncoder(encoder)
if err := encoder.Encode(value); err != nil {
return zero, false
}
// Decode using pooled decoder
var result T
if err := json.Unmarshal(data, &result); err != nil {
decoder := pm.GetJSONDecoder(bytes.NewReader(buf.Bytes()))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&result); err != nil {
return zero, false
}
@@ -278,12 +292,12 @@ type SessionCache struct {
// SessionData represents session information
type SessionData struct {
ExpiresAt time.Time `json:"expires_at"`
Claims map[string]interface{} `json:"claims"`
ID string `json:"id"`
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt time.Time `json:"expires_at"`
Claims map[string]interface{} `json:"claims"`
}
// NewSessionCache creates a new session cache

Some files were not shown because too many files have changed in this diff Show More